-- | See Shaw, Ewart [\"Hypergeometric Functions and CDFs in J\"](https://www.jsoftware.com/papers/jhyper.pdf).
module Math.Hypergeometric ( hypergeometric
                           , euler
                           , erf
                           , ncdf
                           ) where

import           Data.Functor ((<$>))

-- | Rising factorial (Pochhammer symbol)
--
-- \( (a)_n = a (a+1) \cdots (a+n-1) \), with \( (a)_0 = 1 \).
--
-- The factors are multiplied in the order
-- @(a+n-1) * ((a+n-2) * (... * (a * 1)))@.
--
-- Calls 'error' when @n@ is negative: the recursive definition never reaches
-- its base case for such input, so an explicit failure replaces
-- non-termination.
risingFactorial :: Num a => a -> Int -> a
risingFactorial a n
    | n < 0     = error "risingFactorial: negative count"
    | otherwise = foldr step 1 [n, n - 1 .. 1]
  where
    -- Each factor is formed as @a + k - 1@ (not @a + (k - 1)@).
    step k acc = (a + fromIntegral k - 1) * acc

-- | \( n! \) computed in the target numeric type: every integer in @[1..n]@
-- is converted with 'fromIntegral' and the results are multiplied together.
--
-- The range @[1..n]@ is empty for @n < 1@, so the result is then @1@.
factorial :: Num a => Int -> a
factorial upTo = product factors
  where
    factors = fromIntegral <$> enumFromTo 1 upTo

{-# SPECIALIZE ncdf :: Double -> Double #-}
{-# SPECIALIZE ncdf :: Float -> Float #-}
-- | Cumulative distribution function of the standard normal \( N(0,1) \),
-- expressed through 'erf':
--
-- \( \displaystyle \Phi(z) = \frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{z}{\sqrt{2}}\right)\right) \)
ncdf :: (Ord a, Floating a) => a -> a
ncdf z = half * (1 + erf scaled)
  where
    half   = 1 / 2
    scaled = z / sqrt 2

{-# SPECIALIZE erf :: Double -> Double #-}
{-# SPECIALIZE erf :: Float -> Float #-}
-- | [erf](https://mathworld.wolfram.com/Erf.html), evaluated as
--
-- \( \displaystyle \operatorname{erf}(z) = \frac{2z}{\sqrt{\pi}} e^{-z^2} {}_1F_1\left(1;\tfrac{3}{2};z^2\right) \)
--
-- using 'hypergeometric' for the series.
erf :: (Ord a, Floating a) => a -> a
erf z = gaussianFactor * hypergeometric [1] [3 / 2] zSquared
  where
    zSquared       = z ^ (2 :: Int)
    gaussianFactor = 2 * z * exp (negate zSquared) / sqrt pi

{-# SPECIALIZE euler :: Double -> Double -> Double -> Double -> Double #-}
{-# SPECIALIZE euler :: Float -> Float -> Float -> Float -> Float #-}
-- | Euler's transform:
--
-- \( \displaystyle _2F_1(a,b;c;z) = (1-z)^{-a} {}_2F_1\left(a,c-b;c;\frac{z}{z-1}\right) \)
--
-- Evaluates the right-hand side of the identity above, i.e. the prefactor
-- \( (1-z)^{-a} \) times the series at the transformed argument
-- \( z/(z-1) \). The transformed argument is passed to 'hypergeometric', so
-- it is subject to that function's radius-of-convergence check; @z@ must not
-- be @1@.
--
-- Koekoek, Roelef and Swarttouw, René F. [The Askey-scheme of hypergeometric orthogonal polynomials and its q-analogue](https://arxiv.org/abs/math/9602214).
--
-- @since 0.1.5.0
euler :: (Ord a, Floating a)
      => a -- ^ \(a\)
      -> a -- ^ \(b\)
      -> a -- ^ \(c\)
      -> a -- ^ \(z\)
      -> a
euler a b c z = prefactor * hypergeometric [a, c - b] [c] transformed
  where
    -- (1 - z)^(-a)
    prefactor   = (1 - z) ** negate a
    -- Parenthesised on purpose: @z / z - 1@ would parse as @(z / z) - 1@.
    transformed = z / (z - 1)

{-# SPECIALIZE hypergeometric :: [Double] -> [Double] -> Double -> Double #-}
{-# SPECIALIZE hypergeometric :: [Float] -> [Float] -> Float -> Float #-}
-- | \( _pF_q(a_1,\ldots,a_p;b_1,\ldots,b_q;z) = \displaystyle\sum_{n=0}^\infty\frac{(a_1)_n\cdots(a_p)_n}{(b_1)_n\cdots(b_q)_n}\frac{z^n}{n!} \)
--
-- The radius of convergence is
--
-- \( \rho = \begin{cases} \infty & \text{if} & p<q+1 \\ 1 & \text{if} & p=q+1 \\ 0 & \text{if} & p>q+1 \\ \end{cases} \)
--
-- This iterates until the result stabilizes.
--
-- Calls 'error' with @\"Outside the radius of convergence.\"@ when
-- \( p = q+1 \) and \( |z| > 1 \), or when \( p > q+1 \) and \( z \neq 0 \).
-- The boundary \( |z| = 1 \) is not rejected.
hypergeometric :: (Ord a, Fractional a)
               => [a] -- ^ \( a_1,\ldots,a_p \)
               -> [a] -- ^ \( b_1,\ldots,b_q \)
               -> a -- ^ \( z \)
               -> a
hypergeometric as bs z = sumUntilEq (term <$> [0 ..])
  where
    -- n-th term of the series.
    term n =
        (pochhammers as / pochhammers bs) * (zChecked ^ n) / factorial n
      where
        pochhammers xs = product (fmap (`risingFactorial` n) xs)

    outside = error "Outside the radius of convergence."

    -- The argument, validated against the radius of convergence. The check
    -- is lazy: it fires the first time a term actually needs @z@.
    zChecked = case compare (length as) (length bs + 1) of
        LT -> z
        EQ | abs z > 1 -> outside -- radius 1 bounds |z|, so negative z counts too
           | otherwise -> z
        GT | z /= 0    -> outside
           | otherwise -> z

-- | Sum the terms of a series, starting from an accumulator of zero and
-- delegating to 'sumUntilEqLoop'.
sumUntilEq :: (Eq a, Num a) => [a] -> a
sumUntilEq terms = sumUntilEqLoop 0 terms

-- | Add terms to the accumulator two at a time, stopping as soon as adding the
-- second term of a pair leaves the running sum unchanged.
--
-- Note that an exactly-zero term in the second position of a pair also stops
-- the summation, since it cannot change the sum.
--
-- A list that runs out before the sum stabilizes is summed to its end, so
-- finite inputs of any length are handled.
sumUntilEqLoop :: (Eq a, Num a) => a -> [a] -> a
sumUntilEqLoop acc (x:y:rest)
    | withX == withXY = withX
    | otherwise       = sumUntilEqLoop withXY rest
  where
    withX  = acc + x
    withXY = withX + y
sumUntilEqLoop acc [x] = acc + x
sumUntilEqLoop acc []  = acc