{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

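-- | GMP utilities: marshalling Haskell 'Integer' and 'Rational' values to and
--   from GMP's @mpz_t@ and @mpq_t@.  A simple example with probable primes:
--
-- > import Numeric.GMP.Raw.Safe (mpz_nextprime)
-- > import Numeric.GMP.Utils (withInInteger, withOutInteger_)
-- > import System.IO.Unsafe (unsafePerformIO)
-- >
-- > nextPrime :: Integer -> Integer
-- > nextPrime n =
-- >   unsafePerformIO $
-- >     withOutInteger_ $ \rop ->
-- >       withInInteger n $ \op ->
-- >         mpz_nextprime rop op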
module Numeric.GMP.Utils
  ( -- * Integer marshalling
    withInInteger'
  , withInInteger
  , withInOutInteger
  , withInOutInteger_
  , withOutInteger
  , withOutInteger_
  , peekInteger'
  , peekInteger
  , pokeInteger
    -- * Rational marshalling
  , withInRational'
  , withInRational
  , withInOutRational
  , withInOutRational_
  , withOutRational
  , withOutRational_
  , peekRational'
  , peekRational
  , pokeRational
  ) where

import Control.Exception (bracket_)
import Data.Ratio ((%), numerator, denominator)
import Foreign (allocaBytes, alloca, with, sizeOf, peek)

#if MIN_VERSION_base(4,15,0)

#define GHC_BIGNUM 1
import GHC.Num.Integer
  ( Integer(..)
  , integerFromBigNat#
  , integerFromBigNatNeg#
  )
import GHC.Num.BigNat
  ( bigNatSize#
  )

#else

#define GHC_BIGNUM 0

import GHC.Integer.GMP.Internals
  ( Integer(..)
  , BigNat(..)
  , sizeofBigNat#
  , byteArrayToBigNat#
  , bigNatToInteger
  , bigNatToNegInteger
  )

#define IS S#

#endif


import GHC.Prim
  ( ByteArray#
  , sizeofByteArray#
  , copyByteArrayToAddr#
  , newByteArray#
  , copyAddrToByteArray#
  , unsafeFreezeByteArray#
  )
import GHC.Exts (Int(..), Ptr(..))
import GHC.Types (IO(..))

import Numeric.GMP.Types

import Numeric.GMP.Raw.Unsafe
  ( mpz_init
  , mpz_clear
  , mpq_init
  , mpq_clear
  , mpz_set
  )

foreign import ccall unsafe "mpz_set_HsInt" -- implemented in wrappers.c
  mpz_set_HsInt :: Ptr MPZ -> Int -> IO ()


-- | Marshal an 'Integer' into a temporary 'MPZ' and run an action with it.
--   The action must treat the 'MPZ' purely as an @mpz_srcptr@ (ie,
--   constant/immutable), and must not let references to it escape its scope.
--
--   * Small integers are placed in a freshly @mpz_init@-ed @mpz_t@ (via a C
--     helper), which is @mpz_clear@-ed again when the action finishes.
--
--   * Big integers are described directly: their limb array is copied into a
--     temporary buffer (see 'withByteArray').  The 'MPZ' points at that
--     buffer, has @mpzAlloc = 0@, and carries the sign in @mpzSize@.
withInInteger' :: Integer -> (MPZ -> IO r) -> IO r
withInInteger' i action = case i of
  IS n# ->
    alloca $ \tmp ->
      bracket_ (mpz_init tmp) (mpz_clear tmp) $ do
        -- a bit awkward, TODO figure out how to do this without foreign calls?
        mpz_set_HsInt tmp (I# n#)
        peek tmp >>= action
#if GHC_BIGNUM
  IP ba# -> fromLimbs ba# (I# (bigNatSize# ba#))
  IN ba# -> fromLimbs ba# (negate (I# (bigNatSize# ba#)))
#else
  Jp# bn@(BN# ba#) -> fromLimbs ba# (I# (sizeofBigNat# bn))
  Jn# bn@(BN# ba#) -> fromLimbs ba# (negate (I# (sizeofBigNat# bn)))
#endif
  where
    -- Expose a limb array as a read-only MPZ.  The limb count is signed:
    -- negative for negative Integers, matching GMP's size convention.
    fromLimbs ba# signedLimbs = withByteArray ba# $ \limbs _ -> action MPZ
      { mpzAlloc = 0
      , mpzSize  = fromIntegral signedLimbs
      , mpzD     = limbs
      }

-- | Copy the entire contents of a 'ByteArray#' into a temporary buffer from
--   'allocaBytes', then call the continuation with a pointer to that buffer
--   and its length in bytes.  The buffer is freed when the continuation
--   returns, so the pointer must not escape.
--   NOTE(review): the copy is presumably there so C code sees a stable
--   address even if the GC moves an unpinned heap array; confirm.
withByteArray :: ByteArray# -> (Ptr a -> Int -> IO r) -> IO r
withByteArray ba# k =
  let size = I# (sizeofByteArray# ba#)
  in allocaBytes size $ \buf@(Ptr dst#) -> do
       IO $ \s -> case copyByteArrayToAddr# ba# 0# dst# (sizeofByteArray# ba#) s of
         s' -> (# s', () #)
       k buf size


-- | 'withInInteger'' followed by 'with'.  The temporary 'MPZ' is stored in a
--   scratch buffer, and the action receives a pointer to it.  The action must
--   use the pointer only as an @mpz_srcptr@ (ie, constant/immutable), and must
--   not let it escape its scope.  If the action might mutate its argument,
--   use 'withInOutInteger' instead.
withInInteger :: Integer -> (Ptr MPZ -> IO r) -> IO r
withInInteger i action = withInInteger' i (`with` action)


-- | Run an action on a freshly initialized @mpz_t@ that has been set to the
--   given 'Integer'.  Afterwards, return the @mpz_t@'s final value together
--   with the action's result; the @mpz_t@ is then cleared.  The pointer must
--   not escape the scope of the action.
withInOutInteger :: Integer -> (Ptr MPZ -> IO a) -> IO (Integer, a)
withInOutInteger n action = withOutInteger $ \z ->
  pokeInteger z n >> action z


-- | Like 'withInOutInteger', but returns only the final value of the
--   @mpz_t@ and discards the action's own result.  The pointer must not
--   escape the scope of the action.
withInOutInteger_ :: Integer -> (Ptr MPZ -> IO a) -> IO Integer
withInOutInteger_ n action =
  withInOutInteger n action >>= \(z, _) -> return z


-- | Run an action on a freshly @mpz_init@-ed @mpz_t@.  Then read the
--   @mpz_t@'s value as an 'Integer' and pair it with the action's result.
--   The @mpz_t@ is @mpz_clear@-ed even if the action throws.  The pointer
--   must not escape the scope of the action.
withOutInteger :: (Ptr MPZ -> IO a) -> IO (Integer, a)
withOutInteger action = alloca $ \ptr ->
  let acquire = mpz_init ptr
      release = mpz_clear ptr
      body = do
        result <- action ptr
        value  <- peekInteger ptr
        return (value, result)
  in bracket_ acquire release body


-- | Like 'withOutInteger', but returns only the value left in the @mpz_t@
--   and discards the action's own result.  The pointer must not escape the
--   scope of the action.
withOutInteger_ :: (Ptr MPZ -> IO a) -> IO Integer
withOutInteger_ action =
  withOutInteger action >>= \(z, _) -> return z


-- | Store an 'Integer' into an @mpz_t@, which must already have been
--   initialized with @mpz_init@.
pokeInteger :: Ptr MPZ -> Integer -> IO ()
pokeInteger dst j = case j of
  -- Small values go straight through the C helper.
  IS n# -> mpz_set_HsInt dst (I# n#)
  -- Big values are copied twice: once into the temporary limb buffer set up
  -- by 'withInInteger', and again by @mpz_set@.  Could maybe be rewritten to
  -- do a single copy using GMP's own allocation functions.
  _ -> withInInteger j (mpz_set dst)


-- | Read an 'Integer' from an 'MPZ'.  A size of zero yields 0.  Otherwise
--   @abs size@ limbs are copied from @mpzD@ into a fresh 'ByteArray#', and
--   the sign of @mpzSize@ gives the sign of the result.
peekInteger' :: MPZ -> IO Integer
peekInteger' MPZ{ mpzSize = size, mpzD = d }
  | size == 0 = return 0
  | otherwise =
      -- One copy, from 'Ptr' 'MPLimb' into a 'ByteArray#'; the conversion
      -- below hopefully won't need to copy it again.
      asByteArray d byteCount (\ba# -> return (limbsToInteger ba#))
  where
    limbCount = fromIntegral (abs size) :: Int
    byteCount = limbCount * sizeOf (undefined :: MPLimb)
    negative  = size < 0
#if GHC_BIGNUM
    limbsToInteger ba#
      | negative  = integerFromBigNatNeg# ba#
      | otherwise = integerFromBigNat# ba#
#else
    limbsToInteger ba# = case limbCount of
      I# n# -> (if negative then bigNatToNegInteger else bigNatToInteger)
                 (byteArrayToBigNat# ba# n#)
#endif

-- | Copy @len@ bytes starting at a pointer into a newly allocated
--   'ByteArray#', freeze it, and pass the now-immutable array to the
--   continuation.  The array is an ordinary heap value, so the
--   continuation's result may keep referring to it.
asByteArray :: Ptr a -> Int -> (ByteArray# -> IO r) -> IO r
asByteArray (Ptr src#) (I# len#) k = IO $ \s0 ->
  case newByteArray# len# s0 of
    (# s1, marr# #) -> case copyAddrToByteArray# src# marr# 0# len# s1 of
      s2 -> case unsafeFreezeByteArray# marr# s2 of
        (# s3, arr# #) -> case k arr# of
          IO run -> run s3


-- | Read the 'MPZ' a pointer refers to (via 'peek'), then convert it with
--   'peekInteger''.
peekInteger :: Ptr MPZ -> IO Integer
peekInteger src = peek src >>= peekInteger'


-- | Marshal a 'Rational' into a temporary 'MPQ' built from its numerator
--   and denominator (each via 'withInInteger''), and run an action with it.
--   The action must treat it only as an @mpq_srcptr@ (ie,
--   constant/immutable), and must not let references escape its scope.
withInRational' :: Rational -> (MPQ -> IO r) -> IO r
withInRational' q action =
  withInInteger' num $ \nz -> withInInteger' den (action . MPQ nz)
  where
    num = numerator q
    den = denominator q


-- | 'withInRational'' followed by 'with'.  The temporary 'MPQ' is stored in
--   a scratch buffer, and the action receives a pointer to it.  The action
--   must use the pointer only as an @mpq_srcptr@ (ie, constant/immutable),
--   and must not let it escape its scope.  If the action might mutate its
--   argument, use 'withInOutRational' instead.
withInRational :: Rational -> (Ptr MPQ -> IO r) -> IO r
withInRational q action = withInRational' q (`with` action)


-- | Run an action on a freshly initialized @mpq_t@ that has been set to the
--   given 'Rational'.  Afterwards, return the @mpq_t@'s final value together
--   with the action's result; the @mpq_t@ is then cleared.  The pointer must
--   not escape the scope of the action.
withInOutRational :: Rational -> (Ptr MPQ -> IO a) -> IO (Rational, a)
withInOutRational n action = withOutRational $ \q ->
  pokeRational q n >> action q


-- | Like 'withInOutRational', but returns only the final value of the
--   @mpq_t@ and discards the action's own result.  The pointer must not
--   escape the scope of the action.
withInOutRational_ :: Rational -> (Ptr MPQ -> IO a) -> IO Rational
withInOutRational_ n action =
  withInOutRational n action >>= \(q, _) -> return q


-- | Run an action on a freshly @mpq_init@-ed @mpq_t@.  Then read the
--   @mpq_t@'s value as a 'Rational' and pair it with the action's result.
--   The @mpq_t@ is @mpq_clear@-ed even if the action throws.  The pointer
--   must not escape the scope of the action.
withOutRational :: (Ptr MPQ -> IO a) -> IO (Rational, a)
withOutRational action = alloca $ \ptr ->
  let acquire = mpq_init ptr
      release = mpq_clear ptr
      body = do
        result <- action ptr
        value  <- peekRational ptr
        return (value, result)
  in bracket_ acquire release body


-- | Like 'withOutRational', but returns only the value left in the @mpq_t@
--   and discards the action's own result.  The pointer must not escape the
--   scope of the action.
withOutRational_ :: (Ptr MPQ -> IO a) -> IO Rational
withOutRational_ action =
  withOutRational action >>= \(q, _) -> return q


-- | Store a 'Rational' into an @mpq_t@, which must already have been
--   initialized with @mpq_init@.  The numerator and the denominator are
--   written in that order.
pokeRational :: Ptr MPQ -> Rational -> IO ()
pokeRational ptr q =
  pokeInteger numRef (numerator q) >> pokeInteger denRef (denominator q)
  where
    numRef = mpq_numref ptr
    denRef = mpq_denref ptr


-- | Read a 'Rational' from an 'MPQ'.  The numerator is read first, then the
--   denominator, and the two are combined with '%'.
peekRational' :: MPQ -> IO Rational
peekRational' (MPQ n d) = (%) <$> peekInteger' n <*> peekInteger' d


-- | Read the 'MPQ' a pointer refers to (via 'peek'), then convert it with
--   'peekRational''.
peekRational :: Ptr MPQ -> IO Rational
peekRational src = peekRational' =<< peek src