-- |
-- Module:      Math.NumberTheory.Curves.Montgomery
-- Copyright:   (c) 2017 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Arithmetic on Montgomery elliptic curves.
-- This is an internal module, exposed only for purposes of testing.
--

{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE MagicHash           #-}
{-# LANGUAGE ScopedTypeVariables #-}

{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# OPTIONS_HADDOCK hide #-}

module Math.NumberTheory.Curves.Montgomery
  ( Point
  , pointX
  , pointZ
  , pointN
  , pointA24
  , SomePoint(..)
  , newPoint
  , add
  , double
  , multiply
  ) where

import Data.Proxy
import GHC.Exts
import GHC.Integer.Logarithms
import GHC.TypeNats (KnownNat, SomeNat(..), Nat, natVal, someNatVal)

import Math.NumberTheory.Utils (recipMod)

-- | We use the Montgomery form of elliptic curve:
-- b Y² = X³ + a X² + X (mod n).
-- See Eq. (10.3.1.1) at p. 260 of <http://www.ams.org/journals/mcom/1987-48-177/S0025-5718-1987-0866113-7/S0025-5718-1987-0866113-7.pdf Speeding the Pollard and Elliptic Curve Methods of Factorization> by P. L. Montgomery.
--
-- Switching to projective space by substitutions Y = y \/ z, X = x \/ z,
-- we get b y² z = x³ + a x² z + x z² (mod n).
-- A point on the projective elliptic curve is characterized by three coordinates,
-- but only the x- and z-components matter for the computations in this module.
-- For the same reason there is no need to store the coefficient b.
--
-- Hence a curve is represented by a24 = (a + 2) \/ 4 and the modulus n,
-- both kept at the type level, which makes points on different curves
-- incompatible.
data Point (a24 :: Nat) (n :: Nat) = Point
  { -- | Extract x-coordinate.
    pointX :: !Integer
  , -- | Extract z-coordinate.
    pointZ :: !Integer
  }

-- | Extract (a + 2) \/ 4, where a is the coefficient in the curve's equation.
-- The value comes solely from the type-level parameter @a24@; the point's
-- coordinates are never inspected.
pointA24 :: forall a24 n. KnownNat a24 => Point a24 n -> Integer
pointA24 = const (fromIntegral (natVal (Proxy :: Proxy a24)))

-- | Extract the modulus of the curve.
-- The value comes solely from the type-level parameter @n@; the point's
-- coordinates are never inspected.
pointN :: forall a24 n. KnownNat n => Point a24 n -> Integer
pointN = const (fromIntegral (natVal (Proxy :: Proxy n)))

-- | In projective space 'Point's are equal if they are both at infinity
-- (z-coordinate 0), or if neither is and the ratios 'pointX' \/ 'pointZ'
-- agree, i.e. x1 z2 ≡ x2 z1 (mod n).
-- A point at infinity is never equal to a finite point.
instance KnownNat n => Eq (Point a24 n) where
  p == q = case (pointZ p, pointZ q) of
    (0, 0)   -> True
    (0, _)   -> False
    (_, 0)   -> False
    -- Cross-multiplication avoids computing modular inverses of z1 and z2.
    (z1, z2) -> (pointX p * z2 - pointX q * z1) `rem` pointN p == 0

-- | For debugging. Renders as @(x, z) (a24 A, mod N)@.
instance (KnownNat a24, KnownNat n) => Show (Point a24 n) where
  show p = concat
    [ "(", show (pointX p)
    , ", ", show (pointZ p)
    , ") (a24 ", show (pointA24 p)
    , ", mod ", show (pointN p)
    , ")"
    ]

-- | Point on a curve that is unknown at compile time: the type-level
-- parameters @a24@ and @n@ are hidden, and their 'KnownNat' evidence is
-- packed together with the point.
data SomePoint where
  SomePoint
    :: (KnownNat a24, KnownNat n)
    => Point a24 n
    -> SomePoint

-- | Delegates to the 'Show' instance of the wrapped 'Point'.
instance Show SomePoint where
  show sp = case sp of
    SomePoint pt -> show pt

-- | 'newPoint' @s@ @n@ creates a point on an elliptic curve modulo @n@, uniquely determined by seed @s@.
-- Some choices of @s@ and @n@ produce ill-parametrized curves, which is reflected by return value 'Nothing'.
--
-- We choose a curve by Suyama's parametrization. See Eq. (3)-(4) at p. 4
-- of <http://www.hyperelliptic.org/tanja/SHARCS/talks06/Gaj.pdf Implementing the Elliptic Curve Method of Factoring in Reconfigurable Hardware>
-- by K. Gaj, S. Kwon et al.
-- With u = s² − 5 and v = 4 s, the point is (u³ : v³) and
-- a24 = (v − u)³ (3 u + v) \/ (16 u³ v) (mod n).
--
-- 'Nothing' is returned when
--
-- * @n@ is not positive;
-- * 'recipMod' finds no inverse of 16 u³ v modulo @n@;
-- * a24 ≡ 0 or a24 ≡ 1 (mod n), which correspond to the singular curves
--   with a = −2 and a = 2 respectively.
newPoint :: Integer -> Integer -> Maybe SomePoint
newPoint s n
  -- Reject a non-positive modulus before any reduction by n:
  -- with n == 0 the `mod` / `rem` below would throw a division-by-zero error.
  | n <= 0 = Nothing
  | otherwise = do
      a24denRecip <- recipMod a24den n
      a24 <- case a24num * a24denRecip `rem` n of
        -- (a+2)/4 = 0 corresponds to singular curve with A = -2
        0 -> Nothing
        -- (a+2)/4 = 1 corresponds to singular curve with A = 2
        1 -> Nothing
        t -> Just t
      -- Defensive: converting a negative value to Natural would throw.
      SomeNat (_ :: Proxy a24Ty) <- if a24 < 0
                                      then Nothing
                                      else Just $ someNatVal $ fromInteger a24
      -- n > 0 is guaranteed by the guard above.
      SomeNat (_ :: Proxy nTy) <- Just $ someNatVal $ fromInteger n
      return $ SomePoint (Point x z :: Point a24Ty nTy)
  where
    u      = s * s `rem` n - 5
    v      = 4 * s
    d      = v - u
    x      = u * u * u `mod` n
    z      = v * v * v `mod` n
    a24num = d * d * d * (3 * u + v) `mod` n
    a24den = 16 * x * v `rem` n

-- | Differential addition: if @p0@ + @p1@ = @p2@, then 'add' @p0@ @p1@ @p2@
-- equals @p1@ + @p2@.
-- It is also required that the z-coordinates of @p0@, @p1@ and @p2@ are coprime
-- with the modulus of the elliptic curve, and that the x-coordinate of @p0@ is non-zero.
-- If these preconditions do not hold, the return value is undefined.
--
-- Remarkably, such addition does not require a 'KnownNat' @a24@ constraint.
--
-- With s = (x1 − z1)(x2 + z2) and t = (x1 + z1)(x2 − z2), the result is
-- (z0 (s + t)² : x0 (s − t)²) modulo n.
-- Computations follow Algorithm 3 at p. 4
-- of <http://www.hyperelliptic.org/tanja/SHARCS/talks06/Gaj.pdf Implementing the Elliptic Curve Method of Factoring in Reconfigurable Hardware>
-- by K. Gaj, S. Kwon et al.
add :: KnownNat n => Point a24 n -> Point a24 n -> Point a24 n -> Point a24 n
add p0@(Point x0 z0) (Point x1 z1) (Point x2 z2) =
    Point (z0 * sqr (s + t) `mod` m) (x0 * sqr (s - t) `mod` m)
  where
    m = pointN p0
    s = (x1 - z1) * (x2 + z2) `rem` m
    t = (x1 + z1) * (x2 - z2) `rem` m
    -- Intermediate reductions use `rem` and may be negative;
    -- the final `mod` normalises the stored coordinates to [0, m).
    sqr k = k * k `rem` m

-- | Multiply by 2.
--
-- With (x + z)² and (x − z)², whose difference is congruent to 4 x z, the result is
-- ((x + z)² (x − z)² : 4 x z ((x − z)² + a24 · 4 x z)) modulo n.
-- Computations follow Algorithm 3 at p. 4
-- of <http://www.hyperelliptic.org/tanja/SHARCS/talks06/Gaj.pdf Implementing the Elliptic Curve Method of Factoring in Reconfigurable Hardware>
-- by K. Gaj, S. Kwon et al.
double :: (KnownNat a24, KnownNat n) => Point a24 n -> Point a24 n
double pt@(Point x z) =
    Point (plusSq * minusSq `mod` m)
          ((minusSq + pointA24 pt * cross `rem` m) * cross `mod` m)
  where
    m       = pointN pt
    plusSq  = (x + z) * (x + z) `rem` m
    minusSq = (x - z) * (x - z) `rem` m
    -- Congruent to 4 x z modulo m.
    cross   = plusSq - minusSq

-- | Multiply a point by the given number, using a binary (Montgomery ladder)
-- algorithm.
--
-- The bits of the multiplier are scanned from the one just below the most
-- significant bit down to bit 0. The loop keeps a pair (lo, hi) with
-- hi = lo + @p@, so every step needs exactly one 'add' (with @p@ as the known
-- difference) and one 'double'.
--
-- @'multiply' 0 p@ returns @Point 0 0@, which '==' treats as the point at
-- infinity. @'multiply' 1 p@ returns @p@ unchanged.
multiply :: (KnownNat a24, KnownNat n) => Word -> Point a24 n -> Point a24 n
multiply 0 _ = Point 0 0
multiply 1 p = p
multiply (W# w##) p = ladder (wordLog2# w## -# 1#) p (double p)
  where
    -- Is bit number i# of the multiplier set?
    bitSet i# = isTrue# (neWord# (and# (uncheckedShiftRL# w## i#) 1##) 0##)

    -- Final step: only the lower point of the pair is needed.
    ladder 0# !lo !hi
      | bitSet 0# = add p lo hi
      | otherwise = double lo
    ladder i# lo hi
      | bitSet i# = ladder (i# -# 1#) (add p lo hi) (double hi)
      | otherwise = ladder (i# -# 1#) (double lo) (add p lo hi)