{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      :  Mcmc.Internal.SpecFunctions
-- Description :  Generalized special functions (log gamma, log factorial) for automatic differentiation
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Tue Jul 13 12:53:09 2021.
--
-- The code is adapted from "Numeric.SpecFunctions" and generalized from
-- 'Double' to any 'RealFloat' type.
module Mcmc.Internal.SpecFunctions
  ( logGammaG,
    logFactorialG,
  )
where

import Data.Type.Equality (castWith, sym)
import Data.Typeable
import qualified Data.Vector as VB
import Numeric.Polynomial
import Numeric.SpecFunctions
import Unsafe.Coerce

-- | Square root of the 'Double' machine epsilon, \(2^{-26}\), usable at any
-- 'RealFloat' type. Below this threshold 'logGammaNonDouble' switches to the
-- Laurent expansion of the gamma function.
mSqrtEpsG :: RealFloat a => a
mSqrtEpsG = 1.4901161193847656e-8

-- | The Euler–Mascheroni constant \(\gamma \approx 0.5772\), usable at any
-- 'RealFloat' type.
mEulerMascheroniG :: RealFloat a => a
mEulerMascheroniG = 0.5772156649015328606065121

-- | Generalized version of the logarithm of the gamma function,
-- \(\log \Gamma(z)\). See 'Numeric.SpecFunctions.logGamma'.
--
-- When @a@ is 'Double', the call is forwarded to the optimized
-- 'Numeric.SpecFunctions.logGamma'. For every other type (for example, types
-- used for automatic differentiation) the generic 'logGammaNonDouble' is used.
logGammaG :: (Typeable a, RealFloat a) => a -> a
logGammaG z = case eqT of
  -- 'eqT' yields a type-equality witness @a :~: Double@ exactly when @a@ is
  -- 'Double'. 'castWith' then converts in both directions with full type
  -- checking, replacing the former unchecked 'unsafeCoerce'.
  Just aIsDouble -> castWith (sym aIsDouble) (logGamma (castWith aIsDouble z))
  Nothing -> logGammaNonDouble z
{-# SPECIALIZE logGammaG :: Double -> Double #-}

-- | Generic (non-'Double') implementation of \(\log \Gamma(z)\), following the
-- piecewise scheme of 'Numeric.SpecFunctions.logGamma':
--
-- * @z <= 0@: positive infinity.
-- * @z < 'mSqrtEpsG'@: Laurent expansion \(\Gamma(z) \approx 1/z - \gamma\).
-- * @z < 1@: recurrence \(\log\Gamma(z) = \log\Gamma(z+1) - \log z\). The
--   shifted differences @(z+1) - 1 = z@ and @(z+1) - 2 = z - 1@ are passed
--   directly instead of being recomputed from @z + 1@.
-- * @1 <= z < 2@: rational approximations on @[1, 1.5]@ and @(1.5, 2)@.
-- * @2 <= z < 15@: 'lgammaSmallG'.
-- * otherwise: Lanczos approximation, 'lanczosApproxG'.
logGammaNonDouble :: RealFloat a => a -> a
logGammaNonDouble z
  | z <= 0 = posInf
  | z < mSqrtEpsG = log (1 / z - mEulerMascheroniG)
  | z < 0.5 = lgamma1_15G z zm1 - log z
  | z < 1 = lgamma15_2G z zm1 - log z
  | z <= 1.5 = lgamma1_15G zm1 zm2
  | z < 2 = lgamma15_2G zm1 zm2
  | z < 15 = lgammaSmallG z
  | otherwise = lanczosApproxG z
  where
    posInf = 1 / 0
    zm1 = z - 1
    zm2 = z - 2

-- | Rational approximation of \(\log \Gamma(z)\) for @z@ in about @[1, 1.5]@.
--
-- Takes @zm1 = z - 1@ and @zm2 = z - 2@ instead of @z@, so callers can supply
-- these differences without extra rounding. The result has the form
-- \(r y + r P(zm1) / Q(zm1)\) with \(r = zm1 \cdot zm2\). This vanishes at
-- @z = 1@ and @z = 2@, where \(\log\Gamma\) is zero.
lgamma1_15G :: RealFloat a => a -> a -> a
lgamma1_15G zm1 zm2 = r * y + r * (num / den)
  where
    r = zm1 * zm2
    num = evaluatePolynomial zm1 tableLogGamma_1_15PG
    den = evaluatePolynomial zm1 tableLogGamma_1_15QG
    y = 0.52815341949462890625

-- | Numerator coefficients for 'lgamma1_15G', lowest-order coefficient first
-- (the order consumed by 'evaluatePolynomial').
tableLogGamma_1_15PG :: RealFloat a => VB.Vector a
tableLogGamma_1_15PG =
  VB.fromList
    [ 0.490622454069039543534e-1
    , -0.969117530159521214579e-1
    , -0.414983358359495381969e0
    , -0.406567124211938417342e0
    , -0.158413586390692192217e0
    , -0.240149820648571559892e-1
    , -0.100346687696279557415e-2
    ]
{-# NOINLINE tableLogGamma_1_15PG #-}

-- | Denominator coefficients for 'lgamma1_15G', lowest-order coefficient
-- first. The constant term is normalized to one.
tableLogGamma_1_15QG :: RealFloat a => VB.Vector a
tableLogGamma_1_15QG =
  VB.fromList
    [ 1
    , 0.302349829846463038743e1
    , 0.348739585360723852576e1
    , 0.191415588274426679201e1
    , 0.507137738614363510846e0
    , 0.577039722690451849648e-1
    , 0.195768102601107189171e-2
    ]
{-# NOINLINE tableLogGamma_1_15QG #-}

-- | Rational approximation of \(\log \Gamma(z)\) for @z@ in about @[1.5, 2)@.
--
-- Like 'lgamma1_15G', it takes @zm1 = z - 1@ and @zm2 = z - 2@. The rational
-- part is evaluated at @-zm2 = 2 - z@, and the prefactor
-- \(r = zm1 \cdot zm2\) vanishes at @z = 1@ and @z = 2@.
lgamma15_2G :: RealFloat a => a -> a -> a
lgamma15_2G zm1 zm2 = r * y + r * (num / den)
  where
    w = negate zm2
    r = zm1 * zm2
    num = evaluatePolynomial w tableLogGamma_15_2PG
    den = evaluatePolynomial w tableLogGamma_15_2QG
    y = 0.452017307281494140625

-- | Numerator coefficients for 'lgamma15_2G', lowest-order coefficient first.
tableLogGamma_15_2PG :: RealFloat a => VB.Vector a
tableLogGamma_15_2PG =
  VB.fromList
    [ -0.292329721830270012337e-1
    , 0.144216267757192309184e0
    , -0.142440390738631274135e0
    , 0.542809694055053558157e-1
    , -0.850535976868336437746e-2
    , 0.431171342679297331241e-3
    ]
{-# NOINLINE tableLogGamma_15_2PG #-}

-- | Denominator coefficients for 'lgamma15_2G', lowest-order coefficient
-- first. The constant term is normalized to one.
tableLogGamma_15_2QG :: RealFloat a => VB.Vector a
tableLogGamma_15_2QG =
  VB.fromList
    [ 1
    , -0.150169356054485044494e1
    , 0.846973248876495016101e0
    , -0.220095151814995745555e0
    , 0.25582797155975869989e-1
    , -0.100666795539143372762e-2
    , -0.827193521891290553639e-6
    ]
{-# NOINLINE tableLogGamma_15_2QG #-}

-- | \(\log \Gamma(z)\) for @2 <= z < 15@ (the range in which
-- 'logGammaNonDouble' calls it).
--
-- Applies \(\log\Gamma(w) = \log(w - 1) + \log\Gamma(w - 1)\) repeatedly. The
-- logarithms are summed in order, starting from zero, until the argument drops
-- below 3. The remainder is then handled by 'lgamma2_3G'.
lgammaSmallG :: RealFloat a => a -> a
lgammaSmallG = descend 0
  where
    descend logSum w =
      if w < 3
        then logSum + lgamma2_3G w
        else
          let w' = w - 1
           in descend (logSum + log w') w'

-- | Rational approximation of \(\log \Gamma(z)\) for @z@ in @[2, 3)@.
--
-- The result has the form \(r y + r P(z-2) / Q(z-2)\) with
-- \(r = (z - 2)(z + 1)\). The prefactor vanishes at @z = 2@, where
-- \(\log\Gamma\) is zero.
lgamma2_3G :: RealFloat a => a -> a
lgamma2_3G z = r * y + r * (num / den)
  where
    zm2 = z - 2
    r = zm2 * (z + 1)
    num = evaluatePolynomial zm2 tableLogGamma_2_3PG
    den = evaluatePolynomial zm2 tableLogGamma_2_3QG
    y = 0.158963680267333984375e0

-- | Numerator coefficients for 'lgamma2_3G', lowest-order coefficient first.
tableLogGamma_2_3PG :: RealFloat a => VB.Vector a
tableLogGamma_2_3PG =
  VB.fromList
    [ -0.180355685678449379109e-1
    , 0.25126649619989678683e-1
    , 0.494103151567532234274e-1
    , 0.172491608709613993966e-1
    , -0.259453563205438108893e-3
    , -0.541009869215204396339e-3
    , -0.324588649825948492091e-4
    ]
{-# NOINLINE tableLogGamma_2_3PG #-}

-- | Denominator coefficients for 'lgamma2_3G', lowest-order coefficient
-- first. The constant term is normalized to one.
tableLogGamma_2_3QG :: RealFloat a => VB.Vector a
tableLogGamma_2_3QG =
  VB.fromList
    [ 1
    , 0.196202987197795200688e1
    , 0.148019669424231326694e1
    , 0.541391432071720958364e0
    , 0.988504251128010129477e-1
    , 0.82130967464889339326e-2
    , 0.224936291922115757597e-3
    , -0.223352763208617092964e-6
    ]
{-# NOINLINE tableLogGamma_2_3QG #-}

-- | Lanczos approximation of \(\log \Gamma(z)\), used by 'logGammaNonDouble'
-- for @z >= 15@.
--
-- Computes \((\log(z + g - 1/2) - 1)(z - 1/2) + \log S(z)\). Here \(S\) is the
-- rational function stored in 'tableLanczosG' and \(g\) is the Lanczos shift
-- parameter.
lanczosApproxG :: RealFloat a => a -> a
lanczosApproxG z = powerTerm + seriesTerm
  where
    g = 6.024680040776729583740234375
    powerTerm = (log (z + g - 0.5) - 1) * (z - 0.5)
    seriesTerm = log (evalRatioG tableLanczosG z)

-- | Coefficient pairs @(numerator, denominator)@ of the Lanczos rational
-- function, lowest power of @z@ first, as consumed by 'evalRatioG'.
--
-- The denominator coefficients are those of the rising product
-- \(z (z+1) \cdots (z+11)\). For example, the constant term is 0, the linear
-- term is \(11! = 39916800\), and the @z^11@ term is \(0 + 1 + \dots + 11 = 66\).
tableLanczosG :: RealFloat a => VB.Vector (a, a)
tableLanczosG =
  VB.fromList
    [ (56906521.91347156388090791033559122686859, 0)
    , (103794043.1163445451906271053616070238554, 39916800)
    , (86363131.28813859145546927288977868422342, 120543840)
    , (43338889.32467613834773723740590533316085, 150917976)
    , (14605578.08768506808414169982791359218571, 105258076)
    , (3481712.15498064590882071018964774556468, 45995730)
    , (601859.6171681098786670226533699352302507, 13339535)
    , (75999.29304014542649875303443598909137092, 2637558)
    , (6955.999602515376140356310115515198987526, 357423)
    , (449.9445569063168119446858607650988409623, 32670)
    , (19.51992788247617482847860966235652136208, 1925)
    , (0.5098416655656676188125178644804694509993, 66)
    , (0.006061842346248906525783753964555936883222, 1)
    ]
{-# NOINLINE tableLanczosG #-}

-- | Strict pair of running accumulators used by 'evalRatioG'.
data LG a
  = LG
      !a -- ^ numerator accumulator
      !a -- ^ denominator accumulator

-- | Evaluates the rational function \(\sum_i a_i x^i / \sum_i b_i x^i\). The
-- vector holds the pairs @(a_i, b_i)@ in order of increasing @i@.
--
-- For @x <= 1@ both polynomials are evaluated by Horner's scheme in @x@
-- (right fold). For @x > 1@ they are evaluated in @1/x@ (left fold). This
-- scales numerator and denominator by the same power of @x@, which cancels in
-- the ratio and keeps the intermediate values from growing with @x@.
evalRatioG :: RealFloat a => VB.Vector (a, a) -> a -> a
evalRatioG coef x = ratio accumulated
  where
    accumulated
      | x > 1 = VB.foldl' stepInv (LG 0 0) coef
      | otherwise = VB.foldr' step (LG 0 0) coef
    ratio (LG p q) = p / q
    step (c, d) (LG p q) = LG (p * x + c) (q * x + d)
    stepInv (LG p q) (c, d) = LG (p * xInv + c) (q * xInv + d)
    xInv = recip x

-- | Generalized version of the log factorial function, \(\log n!\). See
-- 'Numeric.SpecFunctions.logFactorial'.
--
-- When the result type @b@ is 'Double', the call is forwarded to the optimized
-- 'Numeric.SpecFunctions.logFactorial'. Otherwise 'logFactorialNonDouble' is
-- used.
logFactorialG :: forall a b. (Integral a, RealFloat b, Typeable b) => a -> b
logFactorialG n = case eqT of
  -- 'eqT' yields a witness @b :~: Double@ only when @b@ is 'Double'. This
  -- replaces the former @typeOf (undefined :: b)@ comparison followed by an
  -- unchecked 'unsafeCoerce' with a conversion the compiler verifies.
  Just bIsDouble -> castWith (sym bIsDouble) (logFactorial n)
  Nothing -> logFactorialNonDouble n
{-# SPECIALIZE logFactorialG :: Int -> Double #-}

-- | Generic (non-'Double') implementation of \(\log n!\), following
-- 'Numeric.SpecFunctions.logFactorial':
--
-- * @n < 0@: calls 'error'.
-- * @n <= 170@: logarithm of the tabulated value in 'factorialTable'.
-- * @n < 1500@: Stirling's series for \(\log\Gamma(x)\) with @x = n + 1@,
--   keeping the \(1/(12x)\) and \(-1/(360x^3)\) correction terms.
-- * otherwise: Stirling's series with only the \(1/(12x)\) correction.
logFactorialNonDouble :: (Integral a, RealFloat b) => a -> b
logFactorialNonDouble n
  | n < 0 = error "logFactorialNonDouble: Negative input."
  | n <= 170 = log tabulated
  | n < 1500 = stirling + rx * (oneTwelfth - oneOver360 * rx * rx)
  | otherwise = stirling + oneTwelfth * rx
  where
    -- Only demanded under the @0 <= n <= 170@ guard, so the index is in range.
    tabulated = VB.unsafeIndex factorialTable (fromIntegral n)
    -- n! = Gamma(n + 1), so the series is evaluated at x = n + 1.
    x = fromIntegral n + 1
    rx = recip x
    stirling = (x - 0.5) * log x - x + mLnSqrt2Pi
    oneTwelfth = 1 / 12
    oneOver360 = 1 / 360
{-# SPECIALIZE logFactorialNonDouble :: RealFloat a => Int -> a #-}

-- | \(\log \sqrt{2\pi} \approx 0.9189\), the constant term of Stirling's
-- series, usable at any 'RealFloat' type.
mLnSqrt2Pi :: RealFloat a => a
mLnSqrt2Pi =
  0.9189385332046727417803297364056176398613974736377834128171
{-# INLINE mLnSqrt2Pi #-}

-- | Tabulated factorials: entry @i@ holds @i!@ for @i = 0 .. 170@ (171 values).
-- 170 is the cut-off used by 'logFactorialNonDouble'. The next value, 171!,
-- exceeds the 'Double' range.
factorialTable :: RealFloat a => VB.Vector a
{-# NOINLINE factorialTable #-}
factorialTable =
  VB.fromListN
    171
    [ 1.0
    , 1.0
    , 2.0
    , 6.0
    , 24.0
    , 120.0
    , 720.0
    , 5040.0
    , 40320.0
    , 362880.0
    , 3628800.0
    , 3.99168e7
    , 4.790016e8
    , 6.2270208e9
    , 8.71782912e10
    , 1.307674368e12
    , 2.0922789888e13
    , 3.55687428096e14
    , 6.402373705728e15
    , 1.21645100408832e17
    , 2.43290200817664e18
    , 5.109094217170944e19
    , 1.1240007277776077e21
    , 2.5852016738884974e22
    , 6.204484017332394e23
    , 1.5511210043330984e25
    , 4.032914611266056e26
    , 1.0888869450418352e28
    , 3.0488834461171384e29
    , 8.841761993739702e30
    , 2.6525285981219103e32
    , 8.222838654177922e33
    , 2.631308369336935e35
    , 8.683317618811886e36
    , 2.9523279903960412e38
    , 1.0333147966386144e40
    , 3.719933267899012e41
    , 1.3763753091226343e43
    , 5.23022617466601e44
    , 2.0397882081197442e46
    , 8.159152832478977e47
    , 3.3452526613163803e49
    , 1.4050061177528798e51
    , 6.041526306337383e52
    , 2.6582715747884485e54
    , 1.1962222086548019e56
    , 5.5026221598120885e57
    , 2.5862324151116818e59
    , 1.2413915592536073e61
    , 6.082818640342675e62
    , 3.0414093201713376e64
    , 1.5511187532873822e66
    , 8.065817517094388e67
    , 4.2748832840600255e69
    , 2.308436973392414e71
    , 1.2696403353658275e73
    , 7.109985878048634e74
    , 4.0526919504877214e76
    , 2.3505613312828785e78
    , 1.386831185456898e80
    , 8.32098711274139e81
    , 5.075802138772247e83
    , 3.146997326038793e85
    , 1.9826083154044399e87
    , 1.2688693218588415e89
    , 8.24765059208247e90
    , 5.44344939077443e92
    , 3.647111091818868e94
    , 2.4800355424368305e96
    , 1.711224524281413e98
    , 1.197857166996989e100
    , 8.504785885678623e101
    , 6.1234458376886085e103
    , 4.470115461512684e105
    , 3.307885441519386e107
    , 2.4809140811395396e109
    , 1.88549470166605e111
    , 1.4518309202828586e113
    , 1.1324281178206297e115
    , 8.946182130782974e116
    , 7.15694570462638e118
    , 5.797126020747368e120
    , 4.753643337012841e122
    , 3.9455239697206583e124
    , 3.314240134565353e126
    , 2.81710411438055e128
    , 2.422709538367273e130
    , 2.1077572983795275e132
    , 1.8548264225739844e134
    , 1.650795516090846e136
    , 1.4857159644817613e138
    , 1.352001527678403e140
    , 1.2438414054641305e142
    , 1.1567725070816416e144
    , 1.087366156656743e146
    , 1.0329978488239058e148
    , 9.916779348709496e149
    , 9.619275968248211e151
    , 9.426890448883246e153
    , 9.332621544394413e155
    , 9.332621544394415e157
    , 9.425947759838358e159
    , 9.614466715035125e161
    , 9.902900716486179e163
    , 1.0299016745145626e166
    , 1.0813967582402908e168
    , 1.1462805637347082e170
    , 1.2265202031961378e172
    , 1.3246418194518288e174
    , 1.4438595832024934e176
    , 1.5882455415227428e178
    , 1.7629525510902446e180
    , 1.974506857221074e182
    , 2.2311927486598134e184
    , 2.543559733472187e186
    , 2.9250936934930154e188
    , 3.393108684451898e190
    , 3.9699371608087206e192
    , 4.68452584975429e194
    , 5.574585761207606e196
    , 6.689502913449126e198
    , 8.094298525273443e200
    , 9.875044200833601e202
    , 1.214630436702533e205
    , 1.5061417415111406e207
    , 1.8826771768889257e209
    , 2.372173242880047e211
    , 3.0126600184576594e213
    , 3.856204823625804e215
    , 4.974504222477286e217
    , 6.466855489220473e219
    , 8.471580690878819e221
    , 1.1182486511960041e224
    , 1.4872707060906857e226
    , 1.9929427461615188e228
    , 2.6904727073180504e230
    , 3.6590428819525483e232
    , 5.012888748274991e234
    , 6.917786472619488e236
    , 9.615723196941088e238
    , 1.3462012475717523e241
    , 1.898143759076171e243
    , 2.6953641378881624e245
    , 3.8543707171800725e247
    , 5.5502938327393044e249
    , 8.047926057471992e251
    , 1.1749972043909107e254
    , 1.7272458904546386e256
    , 2.5563239178728654e258
    , 3.808922637630569e260
    , 5.713383956445854e262
    , 8.62720977423324e264
    , 1.3113358856834524e267
    , 2.0063439050956823e269
    , 3.0897696138473508e271
    , 4.789142901463393e273
    , 7.471062926282894e275
    , 1.1729568794264143e278
    , 1.8532718694937346e280
    , 2.946702272495038e282
    , 4.714723635992061e284
    , 7.590705053947218e286
    , 1.2296942187394494e289
    , 2.0044015765453023e291
    , 3.287218585534296e293
    , 5.423910666131589e295
    , 9.003691705778436e297
    , 1.5036165148649988e300
    , 2.526075744973198e302
    , 4.269068009004705e304
    , 7.257415615307998e306
    ]