-- |
-- Module      : Crypto.Number.Compat
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
{-# LANGUAGE CPP           #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE UnboxedTuples #-}
module Crypto.Number.Compat
    ( GmpSupported(..)
    , onGmpUnsupported
    , gmpGcde
    , gmpLog2
    , gmpPowModSecInteger
    , gmpPowModInteger
    , gmpInverse
    , gmpNextPrime
    , gmpTestPrimeMillerRabin
    , gmpSizeInBytes
    , gmpSizeInBits
    , gmpExportInteger
    , gmpExportIntegerLE
    , gmpImportInteger
    , gmpImportIntegerLE
    ) where

#ifndef MIN_VERSION_integer_gmp
#define MIN_VERSION_integer_gmp(a,b,c) 0
#endif

#if MIN_VERSION_integer_gmp(0,5,1)
import GHC.Integer.GMP.Internals
import GHC.Base
import GHC.Integer.Logarithms (integerLog2#)
#endif
import Data.Word
import GHC.Ptr (Ptr(..))

-- | Outcome of an operation that may or may not be backed by GMP.
--
-- 'GmpSupported' wraps the value produced through GMP; 'GmpUnsupported'
-- marks that this build has no GMP implementation for the operation, so
-- the caller is expected to provide a fallback (see 'onGmpUnsupported').
data GmpSupported a = GmpSupported a
                    | GmpUnsupported
                    deriving (Show,Eq)

-- | Simple combinator for the case where an operation is not supported
-- through GMP.
--
-- Returns the wrapped value when the first argument is 'GmpSupported',
-- and the fallback given as second argument when it is 'GmpUnsupported'.
onGmpUnsupported :: GmpSupported a -> a -> a
onGmpUnsupported (GmpSupported a) _ = a
onGmpUnsupported GmpUnsupported   f = f

-- | Compute the extended GCD of two integers through GMP.
--
-- For inputs @a@ and @b@ the result is the triple @(s, t, g)@, where @g@
-- and @s@ are the two values returned by 'gcdExtInteger' and @t@ is
-- derived as @(g - s * a) \`div\` b@, i.e. the coefficient completing
-- @a * s + b * t == g@ when @s@ is the Bezout coefficient for @a@.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpGcde :: Integer -> Integer -> GmpSupported (Integer, Integer, Integer)
#if MIN_VERSION_integer_gmp(0,5,1)
gmpGcde a b =
    GmpSupported (s, t, g)
  where (# g, s #) = gcdExtInteger a b
        -- When b == 0 the term b * t vanishes, so any t satisfies the
        -- identity; use 0 rather than dividing by zero.
        t | b == 0    = 0
          | otherwise = (g - s * a) `div` b
#else
gmpGcde _ _ = GmpUnsupported
#endif

-- | Compute the binary logarithm of an integer through GMP.
--
-- An input of 0 is special-cased to yield 0; every other input is handed
-- to 'integerLog2#' and the result boxed as an 'Int'.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpLog2 :: Integer -> GmpSupported Int
#if MIN_VERSION_integer_gmp(0,5,1)
gmpLog2 0 = GmpSupported 0
gmpLog2 x = GmpSupported (I# (integerLog2# x))
#else
gmpLog2 _ = GmpUnsupported
#endif

-- | Compute the modular exponentiation @b ^ e mod m@ through GMP, using
-- the variant with extra security intended to remain constant-time
-- ('powModSecInteger').
--
-- Availability depends on the integer-gmp version:
--
-- * @>= 1.1.0@: 'GmpUnsupported'
-- * @>= 1.0.2@ and @< 1.1.0@: supported
-- * @>= 1.0.0@ and @< 1.0.2@: 'GmpUnsupported'
-- * @>= 0.5.1@ and @< 1.0.0@: supported
-- * anything else: 'GmpUnsupported'
gmpPowModSecInteger :: Integer -> Integer -> Integer -> GmpSupported Integer
#if MIN_VERSION_integer_gmp(1,1,0)
gmpPowModSecInteger _ _ _ = GmpUnsupported
#elif MIN_VERSION_integer_gmp(1,0,2)
gmpPowModSecInteger b e m = GmpSupported (powModSecInteger b e m)
#elif MIN_VERSION_integer_gmp(1,0,0)
gmpPowModSecInteger _ _ _ = GmpUnsupported
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpPowModSecInteger b e m = GmpSupported (powModSecInteger b e m)
#else
gmpPowModSecInteger _ _ _ = GmpUnsupported
#endif

-- | Compute the modular exponentiation @b ^ e mod m@ through GMP
-- ('powModInteger').
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpPowModInteger :: Integer -> Integer -> Integer -> GmpSupported Integer
#if MIN_VERSION_integer_gmp(0,5,1)
gmpPowModInteger b e m = GmpSupported (powModInteger b e m)
#else
gmpPowModInteger _ _ _ = GmpUnsupported
#endif

-- | Modular inverse of @g@ modulo @m@ through GMP ('recipModInteger').
--
-- A result of 0 from 'recipModInteger' is reported as 'Nothing'; any
-- other result @r@ is returned as @'Just' r@.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpInverse :: Integer -> Integer -> GmpSupported (Maybe Integer)
#if MIN_VERSION_integer_gmp(0,5,1)
gmpInverse g m
    | r == 0    = GmpSupported Nothing
    | otherwise = GmpSupported (Just r)
  where r = recipModInteger g m
#else
gmpInverse _ _ = GmpUnsupported
#endif

-- | Get the next prime from a specific value through GMP
-- ('nextPrimeInteger').
--
-- Supported only for integer-gmp @>= 0.5.1@ and @< 1.1.0@; every other
-- configuration yields 'GmpUnsupported'.
gmpNextPrime :: Integer -> GmpSupported Integer
#if MIN_VERSION_integer_gmp(1,1,0)
gmpNextPrime _ = GmpUnsupported
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpNextPrime n = GmpSupported (nextPrimeInteger n)
#else
gmpNextPrime _ = GmpUnsupported
#endif

-- | Test if a number is prime using Miller Rabin, through GMP
-- ('testPrimeInteger').
--
-- The 'Int' argument is unboxed and forwarded as the number of tries; the
-- 'Integer' under test is forced before the call. A result of @0#@ from
-- 'testPrimeInteger' maps to 'False', anything else to 'True'.
--
-- Supported only for integer-gmp @>= 0.5.1@ and @< 1.1.0@; every other
-- configuration yields 'GmpUnsupported'.
gmpTestPrimeMillerRabin :: Int -> Integer -> GmpSupported Bool
#if MIN_VERSION_integer_gmp(1,1,0)
gmpTestPrimeMillerRabin _ _ = GmpUnsupported
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpTestPrimeMillerRabin (I# tries) !n = GmpSupported $
    case testPrimeInteger n tries of
        0# -> False
        _  -> True
#else
gmpTestPrimeMillerRabin _ _ = GmpUnsupported
#endif

-- | Return the size in bytes of an integer, obtained as the number of
-- digits reported by 'sizeInBaseInteger' in base 256.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpSizeInBytes :: Integer -> GmpSupported Int
#if MIN_VERSION_integer_gmp(0,5,1)
gmpSizeInBytes n = GmpSupported (I# (word2Int# (sizeInBaseInteger n 256#)))
#else
gmpSizeInBytes _ = GmpUnsupported
#endif

-- | Return the size in bits of an integer, obtained as the number of
-- digits reported by 'sizeInBaseInteger' in base 2.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpSizeInBits :: Integer -> GmpSupported Int
#if MIN_VERSION_integer_gmp(0,5,1)
gmpSizeInBits n = GmpSupported (I# (word2Int# (sizeInBaseInteger n 2#)))
#else
gmpSizeInBits _ = GmpUnsupported
#endif

-- | Export an integer to memory in big-endian byte order.
--
-- The returned action writes the integer at the given pointer using
-- 'exportIntegerToAddr' with the order flag set to @1#@; the value that
-- 'exportIntegerToAddr' returns is discarded.
--
-- NOTE(review): presumably the destination must be large enough for the
-- integer (see 'gmpSizeInBytes') -- confirm against the integer-gmp docs.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpExportInteger :: Integer -> Ptr Word8 -> GmpSupported (IO ())
#if MIN_VERSION_integer_gmp(1,0,0)
gmpExportInteger n (Ptr addr) = GmpSupported $ do
    _ <- exportIntegerToAddr n addr 1#
    return ()
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpExportInteger n (Ptr addr) = GmpSupported $ IO $ \s ->
    case exportIntegerToAddr n addr 1# s of
        (# s2, _ #) -> (# s2, () #)
#else
gmpExportInteger _ _ = GmpUnsupported
#endif

-- | Export an integer to memory in little-endian byte order.
--
-- The returned action writes the integer at the given pointer using
-- 'exportIntegerToAddr' with the order flag set to @0#@; the value that
-- 'exportIntegerToAddr' returns is discarded.
--
-- NOTE(review): presumably the destination must be large enough for the
-- integer (see 'gmpSizeInBytes') -- confirm against the integer-gmp docs.
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpExportIntegerLE :: Integer -> Ptr Word8 -> GmpSupported (IO ())
#if MIN_VERSION_integer_gmp(1,0,0)
gmpExportIntegerLE n (Ptr addr) = GmpSupported $ do
    _ <- exportIntegerToAddr n addr 0#
    return ()
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpExportIntegerLE n (Ptr addr) = GmpSupported $ IO $ \s ->
    case exportIntegerToAddr n addr 0# s of
        (# s2, _ #) -> (# s2, () #)
#else
gmpExportIntegerLE _ _ = GmpUnsupported
#endif

-- | Import an integer from memory in big-endian byte order.
--
-- The returned action reads from the given pointer using
-- 'importIntegerFromAddr' with the order flag set to @1#@. The 'Int'
-- argument is converted with 'int2Word#' and passed as the size
-- (presumably the number of bytes to read -- confirm against the
-- integer-gmp docs).
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpImportInteger :: Int -> Ptr Word8 -> GmpSupported (IO Integer)
#if MIN_VERSION_integer_gmp(1,0,0)
gmpImportInteger (I# n) (Ptr addr) = GmpSupported $
    importIntegerFromAddr addr (int2Word# n) 1#
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpImportInteger (I# n) (Ptr addr) = GmpSupported $ IO $ \s ->
    importIntegerFromAddr addr (int2Word# n) 1# s
#else
gmpImportInteger _ _ = GmpUnsupported
#endif

-- | Import an integer from memory in little-endian byte order.
--
-- The returned action reads from the given pointer using
-- 'importIntegerFromAddr' with the order flag set to @0#@. The 'Int'
-- argument is converted with 'int2Word#' and passed as the size
-- (presumably the number of bytes to read -- confirm against the
-- integer-gmp docs).
--
-- Returns 'GmpUnsupported' when integer-gmp is older than 0.5.1 (or its
-- version macro is not available).
gmpImportIntegerLE :: Int -> Ptr Word8 -> GmpSupported (IO Integer)
#if MIN_VERSION_integer_gmp(1,0,0)
gmpImportIntegerLE (I# n) (Ptr addr) = GmpSupported $
    importIntegerFromAddr addr (int2Word# n) 0#
#elif MIN_VERSION_integer_gmp(0,5,1)
gmpImportIntegerLE (I# n) (Ptr addr) = GmpSupported $ IO $ \s ->
    importIntegerFromAddr addr (int2Word# n) 0# s
#else
gmpImportIntegerLE _ _ = GmpUnsupported
#endif