{-|
Module      : Numeric.Combinatorics
Copyright   : Copyright (c) 2018 Vanessa McHale

This provides facilities for working with common combinatorial
functions.
-}

module Numeric.Combinatorics ( choose
                             , doubleFactorial
                             , catalan
                             , factorial
                             , derangement
                             , permutations
                             , maxRegions
                             , stirling2
                             ) where

import           Foreign.C
import           Foreign.Ptr
import           Numeric.GMP.Raw.Unsafe (mpz_clear)
import           Numeric.GMP.Types
import           Numeric.GMP.Utils
import           System.IO.Unsafe       (unsafeDupablePerformIO)

-- Raw bindings to the ATS-generated C routines. Every routine returns a
-- pointer to a GMP integer holding its result; the Haskell wrappers in this
-- module read that value and then call mpz_clear on it. The entity strings
-- are spelled out explicitly and equal the Haskell names (the FFI default).

-- Single-argument routines.
foreign import ccall unsafe "factorial_ats"
    factorial_ats :: CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "double_factorial_ats"
    double_factorial_ats :: CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "catalan_ats"
    catalan_ats :: CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "derangements_ats"
    derangements_ats :: CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "max_regions_ats"
    max_regions_ats :: CInt -> IO (Ptr MPZ)

-- Two-argument routines.
foreign import ccall unsafe "choose_ats"
    choose_ats :: CInt -> CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "permutations_ats"
    permutations_ats :: CInt -> CInt -> IO (Ptr MPZ)
foreign import ccall unsafe "stirling2_ats"
    stirling2_ats :: CInt -> CInt -> IO (Ptr MPZ)

-- | Run a one-argument foreign routine that returns a pointer to a GMP
-- integer, copy the result into a Haskell 'Integer', and then release the
-- GMP value with @mpz_clear@.
--
-- The 'Int' argument is narrowed to 'CInt' with 'fromIntegral', so values
-- outside the 'CInt' range are silently truncated (they wrap around).
conjugateMPZ :: (CInt -> IO (Ptr MPZ)) -> Int -> Integer
conjugateMPZ f n = unsafeDupablePerformIO $ do
    -- NOTE(review): unsafeDupablePerformIO presumes the foreign routine is
    -- pure (same result for the same argument, no observable side effects).
    -- If a duplicated evaluation is abandoned part-way, this MPZ may never
    -- reach mpz_clear below.
    mPtr <- f (fromIntegral n)
    -- Read the value out first; '<*' keeps the peeked Integer and then runs
    -- the clear.
    -- NOTE(review): mpz_clear releases the limb storage only; whether the
    -- MPZ struct itself is heap-allocated by the C side (and so also needs
    -- freeing) is not visible here -- confirm against the ATS sources.
    peekInteger mPtr <* mpz_clear mPtr

-- | Two-argument version of 'conjugateMPZ': run a foreign routine on two
-- C @int@s (passed in the given order), copy the GMP result into a Haskell
-- 'Integer', and then release the GMP value with @mpz_clear@.
--
-- Both 'Int' arguments are narrowed to 'CInt' with 'fromIntegral', so values
-- outside the 'CInt' range are silently truncated (they wrap around).
conjugateMPZ' :: (CInt -> CInt -> IO (Ptr MPZ)) -> Int -> Int -> Integer
conjugateMPZ' f n k = unsafeDupablePerformIO $ do
    -- NOTE(review): unsafeDupablePerformIO presumes the foreign routine is
    -- pure (same result for the same arguments, no observable side effects).
    -- If a duplicated evaluation is abandoned part-way, this MPZ may never
    -- reach mpz_clear below.
    mPtr <- f (fromIntegral n) (fromIntegral k)
    -- Read the value out first; '<*' keeps the peeked Integer and then runs
    -- the clear.
    -- NOTE(review): mpz_clear releases the limb storage only; whether the
    -- MPZ struct itself is heap-allocated by the C side (and so also needs
    -- freeing) is not visible here -- confirm against the ATS sources.
    peekInteger mPtr <* mpz_clear mPtr

-- | The number of derangements of @n@ objects (the subfactorial),
-- written \( !n \).
--
-- > λ:> derangement <$> [0..10]
-- > [1,0,1,2,9,44,265,1854,14833,133496,1334961]
derangement :: Int -> Integer
derangement = conjugateMPZ derangements_ats

-- | The @n@th Catalan number, with indexing beginning at @0@.
--
-- > λ:> catalan <$> [0..9]
-- > [1,1,2,5,14,42,132,429,1430,4862]
catalan :: Int -> Integer
catalan = conjugateMPZ catalan_ats

-- | The binomial coefficient \( \binom{n}{k} \), called as @choose n k@.
choose :: Int -> Int -> Integer
choose = conjugateMPZ' choose_ats

-- NOTE(review): presumably this counts k-permutations of n items,
-- n!/(n-k)!, with arguments in (n, k) order -- confirm against
-- permutations_ats before stating it in the Haddock below.

-- | Permutation count computed by the foreign routine @permutations_ats@,
-- called as @permutations n k@.
permutations :: Int -> Int -> Integer
permutations = conjugateMPZ' permutations_ats

-- | Stirling numbers of the second kind, called as @stirling2 n k@.
stirling2 :: Int -> Int -> Integer
stirling2 = conjugateMPZ' stirling2_ats

-- | The factorial \( n! \).
factorial :: Int -> Integer
factorial = conjugateMPZ factorial_ats

-- | The double factorial \( n!! \).
doubleFactorial :: Int -> Integer
doubleFactorial = conjugateMPZ double_factorial_ats

-- | Compute the maximal number of regions obtained by joining \( n \) points
-- placed on a circle with straight lines. See
-- [OEIS A000127](https://oeis.org/A000127).
maxRegions :: Int -- ^ \( n \)
           -> Integer
maxRegions = conjugateMPZ max_regions_ats