-- |
-- Module      : Crypto.ECC.Edwards25519
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Arithmetic primitives over curve edwards25519.
--
-- Twisted Edwards curves are a family of elliptic curves allowing
-- complete addition formulas without any special case and no point at
-- infinity.  Curve edwards25519 is based on prime 2^255 - 19 for
-- efficient implementation.  Equation and parameters are given in
-- <https://tools.ietf.org/html/rfc7748 RFC 7748>.
--
-- This module provides types and primitive operations that are useful
-- to implement cryptographic schemes based on curve edwards25519:
--
-- - arithmetic functions for point addition, doubling, negation,
-- scalar multiplication with an arbitrary point, with the base point,
-- etc.
--
-- - arithmetic functions dealing with scalars modulo the prime order
-- L of the base point
--
-- All functions run in constant time unless noted otherwise.
--
-- Warnings:
--
-- 1. Curve edwards25519 has a cofactor h = 8 so the base point does
-- not generate the entire curve and points with order 2, 4, 8 exist.
-- When implementing cryptographic algorithms, special care must be
-- taken using one of the following methods:
--
--     - points must be checked for membership in the prime-order
--     subgroup
--
--     - or cofactor must be cleared by multiplying points by 8
--
--     Utility functions are provided to implement this.  Testing
--     subgroup membership with 'pointHasPrimeOrder' is 50 times slower
--     than calling 'pointMulByCofactor'.
--
-- 2. Scalar arithmetic is always reduced modulo L, allowing fixed
-- length and constant execution time, but this reduction is valid
-- only when points are in the prime-order subgroup.
--
-- 3. Because of modular reduction in this implementation it is not
-- possible to multiply points directly by scalars like 8.s or L.
-- This has to be decomposed into several steps.
--
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.ECC.Edwards25519
    ( Scalar
    , Point
    -- * Scalars
    , scalarGenerate
    , scalarDecodeLong
    , scalarEncode
    -- * Points
    , pointDecode
    , pointEncode
    , pointHasPrimeOrder
    -- * Arithmetic functions
    , toPoint
    , scalarAdd
    , scalarMul
    , pointNegate
    , pointAdd
    , pointDouble
    , pointMul
    , pointMulByCofactor
    , pointsMulVarTime
    ) where

import           Data.Word
import           Foreign.C.Types
import           Foreign.Ptr

import           Crypto.Error
import           Crypto.Internal.ByteArray (Bytes, ScrubbedBytes, withByteArray)
import qualified Crypto.Internal.ByteArray as B
import           Crypto.Internal.Compat
import           Crypto.Internal.Imports
import           Crypto.Random


-- Size, in bytes, of the buffer allocated for a scalar before it is handed
-- to the C routines as a @Ptr Scalar@.  The value is the larger of the two
-- layouts listed in the trailing comment.
scalarArraySize :: Int
scalarArraySize = 40 -- maximum [9 * 4 {- 32 bits -}, 5 * 8 {- 64 bits -}]

-- | A scalar modulo prime order of curve edwards25519.
--
-- The wrapped buffer is the opaque representation passed to the C
-- routines as a @Ptr Scalar@; use 'scalarEncode' to obtain the 32-byte
-- serialization.
newtype Scalar = Scalar ScrubbedBytes
    deriving (Show,NFData)

-- Equality is delegated to the C routine bound as 'ed25519_scalar_eq':
-- both buffers are exposed as pointers and a non-zero result means the
-- scalars are equal.
instance Eq Scalar where
    (Scalar s1) == (Scalar s2) = unsafeDoIO $
        withByteArray s1 $ \ps1 ->
        withByteArray s2 $ \ps2 ->
            fmap (/= 0) (ed25519_scalar_eq ps1 ps2)
    -- NOTE(review): presumably NOINLINE keeps the 'unsafeDoIO' action from
    -- being duplicated or floated by the optimiser -- confirm.
    {-# NOINLINE (==) #-}

-- Size, in bytes, of the buffer allocated for a point before it is handed
-- to the C routines as a @Ptr Point@.  The value is the larger of the two
-- layouts listed in the trailing comment.
pointArraySize :: Int
pointArraySize = 160 -- maximum [4 * 10 * 4 {- 32 bits -}, 4 * 5 * 8 {- 64 bits -}]

-- | A point on curve edwards25519.
--
-- The wrapped buffer is the opaque representation passed to the C
-- routines as a @Ptr Point@; use 'pointEncode' to obtain the 32-byte
-- serialization.
newtype Point = Point Bytes
    deriving NFData

-- Renders a point as the word @Point@ followed by the shown Base16 form of
-- its 32-byte encoding (see 'pointEncode'), wrapped in parentheses when
-- the surrounding precedence is above 10.
instance Show Point where
    showsPrec d p =
        let bs = pointEncode p :: Bytes
         in showParen (d > 10) $ showString "Point "
                               . shows (B.convertToBase B.Base16 bs :: Bytes)

-- Equality is delegated to the C routine bound as 'ed25519_point_eq':
-- both buffers are exposed as pointers and a non-zero result means the
-- points are equal.
instance Eq Point where
    (Point p1) == (Point p2) = unsafeDoIO $
        withByteArray p1 $ \pp1 ->
        withByteArray p2 $ \pp2 ->
            fmap (/= 0) (ed25519_point_eq pp1 pp2)
    -- NOTE(review): presumably NOINLINE keeps the 'unsafeDoIO' action from
    -- being duplicated or floated by the optimiser -- confirm.
    {-# NOINLINE (==) #-}

-- | Generate a random scalar.
--
-- 64 random bytes are drawn and reduced with 'scalarDecodeLong'.  That
-- length is within the 64-byte limit accepted by 'scalarDecodeLong', so
-- its length check cannot be what makes 'throwCryptoError' fire here.
scalarGenerate :: MonadRandom randomly => randomly Scalar
scalarGenerate = throwCryptoError . scalarDecodeLong <$> generate
  where
    -- Scalar generation is based on a fixed number of bytes so that
    -- there is no timing leak.  But because of modular reduction
    -- distribution is not uniform.  We use many more bytes than
    -- necessary so the probability bias is small.  With 512 bits we
    -- get 22% of scalars with a higher frequency, but the relative
    -- probability difference is only 2^(-260).
    generate :: MonadRandom randomly => randomly ScrubbedBytes
    generate = getRandomBytes 64

-- | Serialize a scalar to binary, i.e. a 32-byte little-endian
-- number.
--
-- A fresh 32-byte array is allocated and filled by the C routine bound as
-- 'ed25519_scalar_encode'.
scalarEncode :: B.ByteArray bs => Scalar -> bs
scalarEncode (Scalar s) =
    B.allocAndFreeze 32 $ \out ->
        withByteArray s $ \ps -> ed25519_scalar_encode out ps

-- | Deserialize a little-endian number as a scalar.  Input array can
-- have any length from 0 to 64 bytes.
--
-- Inputs longer than 64 bytes are rejected with
-- 'CryptoError_EcScalarOutOfBounds'; every other input yields
-- 'CryptoPassed'.
--
-- Note: it is not advised to put secret information in the 3 lowest
-- bits of a scalar if this scalar may be multiplied to untrusted
-- points outside the prime-order subgroup.
scalarDecodeLong :: B.ByteArrayAccess bs => bs -> CryptoFailable Scalar
scalarDecodeLong bs
    | B.length bs > 64 = CryptoFailed CryptoError_EcScalarOutOfBounds
    | otherwise        = unsafeDoIO $ withByteArray bs initialize
  where
    -- Input length as expected by the C side.
    len :: CSize
    len = fromIntegral $ B.length bs

    -- Allocate the scalar buffer and let the C routine fill it from the
    -- input bytes.
    initialize :: Ptr Word8 -> IO (CryptoFailable Scalar)
    initialize inp = do
        s <- B.alloc scalarArraySize $ \ps ->
                 ed25519_scalar_decode_long ps inp len
        return $ CryptoPassed (Scalar s)
{-# NOINLINE scalarDecodeLong #-}

-- | Add two scalars.
--
-- The sum is written by the C routine bound as 'ed25519_scalar_add' into
-- a freshly allocated scalar buffer.
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd (Scalar a) (Scalar b) =
    Scalar $ B.allocAndFreeze scalarArraySize $ \out ->
        withByteArray a $ \pa ->
        withByteArray b $ \pb ->
             ed25519_scalar_add out pa pb

-- | Multiply two scalars.
--
-- The product is written by the C routine bound as 'ed25519_scalar_mul'
-- into a freshly allocated scalar buffer.
scalarMul :: Scalar -> Scalar -> Scalar
scalarMul (Scalar a) (Scalar b) =
    Scalar $ B.allocAndFreeze scalarArraySize $ \out ->
        withByteArray a $ \pa ->
        withByteArray b $ \pb ->
             ed25519_scalar_mul out pa pb

-- | Multiplies a scalar with the curve base point.
--
-- The result is written by the C routine bound as
-- 'ed25519_point_base_scalarmul' into a freshly allocated point buffer.
toPoint :: Scalar -> Point
toPoint (Scalar scalar) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray scalar $ \pscalar ->
            ed25519_point_base_scalarmul out pscalar

-- | Serialize a point to a 32-byte array.
--
-- Format is binary compatible with 'Crypto.PubKey.Ed25519.PublicKey'
-- from module "Crypto.PubKey.Ed25519".
pointEncode :: B.ByteArray bs => Point -> bs
pointEncode (Point p) =
    B.allocAndFreeze 32 $ \out ->
        withByteArray p $ \pp ->
             ed25519_point_encode out pp

-- | Deserialize a 32-byte array as a point, ensuring the point is
-- valid on edwards25519.
--
-- Fails with 'CryptoError_PointSizeInvalid' when the input is not exactly
-- 32 bytes long, and with 'CryptoError_PointCoordinatesInvalid' when the
-- C decoder reports failure.
--
-- /WARNING:/ variable time
pointDecode :: B.ByteArrayAccess bs => bs -> CryptoFailable Point
pointDecode bs
    | B.length bs == 32 = unsafeDoIO $ withByteArray bs initialize
    | otherwise         = CryptoFailed CryptoError_PointSizeInvalid
  where
    initialize :: Ptr Word8 -> IO (CryptoFailable Point)
    initialize inp = do
        (res, p) <- B.allocRet pointArraySize $ \pp ->
                        ed25519_point_decode_vartime pp inp
        -- The C decoder signals failure with a zero result.
        if res == 0 then return $ CryptoFailed CryptoError_PointCoordinatesInvalid
                    else return $ CryptoPassed (Point p)
{-# NOINLINE pointDecode #-}

-- | Test whether a point belongs to the prime-order subgroup
-- generated by the base point.  Result is 'True' for the identity
-- point.
--
-- @
-- pointHasPrimeOrder p = 'pointNegate' p == 'pointMul' l_minus_one p
-- @
pointHasPrimeOrder :: Point -> Bool
pointHasPrimeOrder (Point p) = unsafeDoIO $
    withByteArray p $ \pp ->
        -- A non-zero result from the C routine means membership.
        fmap (/= 0) (ed25519_point_has_prime_order pp)
{-# NOINLINE pointHasPrimeOrder #-}

-- | Negate a point.
--
-- The result is written by the C routine bound as 'ed25519_point_negate'
-- into a freshly allocated point buffer.
pointNegate :: Point -> Point
pointNegate (Point a) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray a $ \pa ->
             ed25519_point_negate out pa

-- | Add two points.
--
-- The sum is written by the C routine bound as 'ed25519_point_add' into a
-- freshly allocated point buffer.
pointAdd :: Point -> Point -> Point
pointAdd (Point a) (Point b) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray a $ \pa ->
        withByteArray b $ \pb ->
             ed25519_point_add out pa pb

-- | Add a point to itself.
--
-- @
-- pointDouble p = 'pointAdd' p p
-- @
pointDouble :: Point -> Point
pointDouble (Point a) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray a $ \pa ->
             ed25519_point_double out pa

-- | Multiply a point by h = 8.
--
-- @
-- pointMulByCofactor p = 'pointMul' scalar_8 p
-- @
pointMulByCofactor :: Point -> Point
pointMulByCofactor (Point a) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray a $ \pa ->
             ed25519_point_mul_by_cofactor out pa

-- | Scalar multiplication over curve edwards25519.
--
-- Note: when the scalar had reduction modulo L and the input point
-- has a torsion component, the output point may not be in the
-- expected subgroup.
pointMul :: Scalar -> Point -> Point
pointMul (Scalar scalar) (Point base) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray scalar $ \pscalar ->
        withByteArray base   $ \pbase   ->
             -- The C routine takes the point before the scalar.
             ed25519_point_scalarmul out pbase pscalar

-- | Multiply the point @p@ by @s2@ and add the base point multiplied
-- by @s1@.
--
-- @
-- pointsMulVarTime s1 s2 p = 'pointAdd' ('toPoint' s1) ('pointMul' s2 p)
-- @
--
-- /WARNING:/ variable time
pointsMulVarTime :: Scalar -> Scalar -> Point -> Point
pointsMulVarTime (Scalar s1) (Scalar s2) (Point p) =
    Point $ B.allocAndFreeze pointArraySize $ \out ->
        withByteArray s1 $ \ps1 ->
        withByteArray s2 $ \ps2 ->
        withByteArray p  $ \pp  ->
             -- Argument order on the C side: scalar1, base2, scalar2.
             ed25519_base_double_scalarmul_vartime out ps1 pp ps2

-- Binding to the C function @cryptonite_ed25519_scalar_eq@ (unsafe call).
-- Takes pointers to two scalar buffers; the 'Eq' instance for 'Scalar'
-- treats a non-zero result as "equal".
foreign import ccall unsafe "cryptonite_ed25519_scalar_eq"
    ed25519_scalar_eq :: Ptr Scalar
                      -> Ptr Scalar
                      -> IO CInt

-- Binding to the C function @cryptonite_ed25519_scalar_encode@ (unsafe
-- call).  First argument is the output byte buffer ('scalarEncode'
-- allocates 32 bytes for it), second is the scalar to serialize.
foreign import ccall unsafe "cryptonite_ed25519_scalar_encode"
    ed25519_scalar_encode :: Ptr Word8
                          -> Ptr Scalar
                          -> IO ()

-- Binding to the C function @cryptonite_ed25519_scalar_decode_long@
-- (unsafe call).  Arguments are the destination scalar buffer, the input
-- bytes and the input length; 'scalarDecodeLong' only calls it with a
-- length of at most 64.
foreign import ccall unsafe "cryptonite_ed25519_scalar_decode_long"
    ed25519_scalar_decode_long :: Ptr Scalar
                               -> Ptr Word8
                               -> CSize
                               -> IO ()

-- Binding to the C function @cryptonite_ed25519_scalar_add@ (unsafe
-- call).  Writes the sum of the two input scalars into the first
-- argument.
foreign import ccall unsafe "cryptonite_ed25519_scalar_add"
    ed25519_scalar_add :: Ptr Scalar -- sum
                       -> Ptr Scalar -- a
                       -> Ptr Scalar -- b
                       -> IO ()

-- Binding to the C function @cryptonite_ed25519_scalar_mul@ (unsafe
-- call).  Writes the product of the two input scalars into the first
-- argument.
foreign import ccall unsafe "cryptonite_ed25519_scalar_mul"
    ed25519_scalar_mul :: Ptr Scalar -- out
                       -> Ptr Scalar -- a
                       -> Ptr Scalar -- b
                       -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_encode@ (unsafe
-- call).  First argument is the output byte buffer ('pointEncode'
-- allocates 32 bytes for it), second is the point to serialize.
foreign import ccall unsafe "cryptonite_ed25519_point_encode"
    ed25519_point_encode :: Ptr Word8
                         -> Ptr Point
                         -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_decode_vartime@
-- (unsafe call).  Arguments are the destination point buffer and the
-- input bytes ('pointDecode' only passes 32-byte inputs).  'pointDecode'
-- treats a zero result as a decoding failure.
foreign import ccall unsafe "cryptonite_ed25519_point_decode_vartime"
    ed25519_point_decode_vartime :: Ptr Point
                                 -> Ptr Word8
                                 -> IO CInt

-- Binding to the C function @cryptonite_ed25519_point_eq@ (unsafe call).
-- Takes pointers to two point buffers; the 'Eq' instance for 'Point'
-- treats a non-zero result as "equal".
foreign import ccall unsafe "cryptonite_ed25519_point_eq"
    ed25519_point_eq :: Ptr Point
                     -> Ptr Point
                     -> IO CInt

-- Binding to the C function @cryptonite_ed25519_point_has_prime_order@.
-- Declared without @unsafe@, so this is a safe foreign call.
-- NOTE(review): presumably safe because the check is comparatively slow
-- -- confirm.  'pointHasPrimeOrder' treats a non-zero result as 'True'.
foreign import ccall "cryptonite_ed25519_point_has_prime_order"
    ed25519_point_has_prime_order :: Ptr Point
                                  -> IO CInt

-- Binding to the C function @cryptonite_ed25519_point_negate@ (unsafe
-- call).  Writes the negation of the second argument into the first.
foreign import ccall unsafe "cryptonite_ed25519_point_negate"
    ed25519_point_negate :: Ptr Point -- minus_a
                         -> Ptr Point -- a
                         -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_add@ (unsafe call).
-- Writes the sum of the two input points into the first argument.
foreign import ccall unsafe "cryptonite_ed25519_point_add"
    ed25519_point_add :: Ptr Point -- sum
                      -> Ptr Point -- a
                      -> Ptr Point -- b
                      -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_double@ (unsafe
-- call).  Writes the double of the second argument into the first.
foreign import ccall unsafe "cryptonite_ed25519_point_double"
    ed25519_point_double :: Ptr Point -- two_a
                         -> Ptr Point -- a
                         -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_mul_by_cofactor@
-- (unsafe call).  Writes eight times the second argument into the first.
foreign import ccall unsafe "cryptonite_ed25519_point_mul_by_cofactor"
    ed25519_point_mul_by_cofactor :: Ptr Point -- eight_a
                                  -> Ptr Point -- a
                                  -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_base_scalarmul@.
-- Declared without @unsafe@, so this is a safe foreign call.  Used by
-- 'toPoint' to multiply the base point by a scalar.
foreign import ccall "cryptonite_ed25519_point_base_scalarmul"
    ed25519_point_base_scalarmul :: Ptr Point  -- scaled
                                 -> Ptr Scalar -- scalar
                                 -> IO ()

-- Binding to the C function @cryptonite_ed25519_point_scalarmul@.
-- Declared without @unsafe@, so this is a safe foreign call.  Note the
-- argument order: output, then the point, then the scalar.
foreign import ccall "cryptonite_ed25519_point_scalarmul"
    ed25519_point_scalarmul :: Ptr Point  -- scaled
                            -> Ptr Point  -- base
                            -> Ptr Scalar -- scalar
                            -> IO ()

-- Binding to the C function
-- @cryptonite_ed25519_base_double_scalarmul_vartime@.  Declared without
-- @unsafe@, so this is a safe foreign call.  Used by 'pointsMulVarTime',
-- which documents the result as scalar1 times the base point plus
-- scalar2 times base2.
foreign import ccall "cryptonite_ed25519_base_double_scalarmul_vartime"
    ed25519_base_double_scalarmul_vartime :: Ptr Point  -- combo
                                          -> Ptr Scalar -- scalar1
                                          -> Ptr Point  -- base2
                                          -> Ptr Scalar -- scalar2
                                          -> IO ()