{-|
Module      : Crypto.PastaCurves.Curves (internal)
Description : Supports the instantiation of parameterized prime-order elliptic curves.
Copyright   : (c) Eric Schorn, 2022
Maintainer  : eric.schorn@nccgroup.com
Stability   : experimental
Portability : GHC
SPDX-License-Identifier: MIT

This internal module provides an elliptic curve (multi-use) template from arbitrary
parameters for a curve of odd order, along with a variety of supporting functionality 
such as point addition, multiplication, negation, equality, serialization and 
deserialization. The algorithms are NOT constant time. Safety and simplicity are the 
top priorities; the curve order must be prime (so no affine curve point has a
y-coordinate of zero).
-}

{-# LANGUAGE CPP, DataKinds, DerivingStrategies, FlexibleInstances, PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses, NoImplicitPrelude, Safe, ScopedTypeVariables #-}

module Curves (Curve(..), CurvePt(..), Point(..)) where

import Prelude hiding (drop, length, sqrt)
import Control.Monad (mfilter)
import Data.ByteString (ByteString, cons, drop, index, length, pack)
import Data.Maybe (fromJust)
import Data.Typeable (Proxy (Proxy))
import GHC.TypeLits (Nat, KnownNat, natVal)
import Fields (Field (..))


-- | The `Point` type incorporates type literals @a@ and @b@ of an elliptic curve in the
-- short Weierstrass normal form @y^2 = x^3 + a*x + b@. It also incorporates @baseX@ and
-- @baseY@ coordinates for the base point. A point with different literals is considered
-- a different type, so cannot be inadvertently mixed.
--
-- Points are held in projective coordinates: the affine point represented is
-- @(_x * inv0 _z, _y * inv0 _z)@, and a point with @_z == 0@ is the neutral point at
-- infinity. The fields are strict so that long chains of point operations (for example
-- in scalar multiplication) do not build up unevaluated coordinate thunks.
-- *The curve order must be prime, and `_y` cannot be zero*.
data Point (a :: Nat) (b :: Nat) (baseX :: Nat) (baseY :: Nat) f =
  Projective
    { _x :: !f  -- ^ projective x-coordinate
    , _y :: !f  -- ^ projective y-coordinate
    , _z :: !f  -- ^ projective z-coordinate; zero for the neutral point
    }
  deriving stock (Show)
            

-- CPP macro 'helpers' to extract the curve parameters from `Point a b baseX baseY f`
-- Each name expands textually to a `natVal` call on the type-level literal of the same
-- lowercase name, producing an `Integer`. The expansion only type-checks where those
-- type variables are in scope, i.e. inside the instance declarations below (instance
-- head variables scope over method bodies under ScopedTypeVariables). The expansion is
-- several tokens long, so use sites parenthesize it or put it after `$`, as in
-- `fromInteger (A)` or `fromInteger $ A`. CPP substitutes these names wherever they
-- appear as standalone tokens later in this file, comments included.
#define A natVal (Proxy :: Proxy a)
#define B natVal (Proxy :: Proxy b)
#define BASE_X natVal (Proxy :: Proxy baseX)
#define BASE_Y natVal (Proxy :: Proxy baseY)


-- Equality of projective points, decided without any field inversion.
instance (Field f, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  Eq (Point a b baseX baseY f) where

  -- Two triples denote the same point when their ratios x/z and y/z agree. Instead of
  -- dividing, cross-multiply: compare x1*z2 with x2*z1, then y1*z2 with y2*z1.
  -- Consequently every triple with z == 0 (the neutral point, in any scaling)
  -- compares equal to every other such triple.
  Projective x1 y1 z1 == Projective x2 y2 z2 =
    sameRatio x1 x2 && sameRatio y1 y2
    where
      sameRatio u v = u * z2 == v * z1


-- | Core point operations for a prime-order elliptic curve: the distinguished base and
-- neutral points, negation, addition, an on-curve check, and compressed
-- (de)serialization. `Point` provides an instance for arbitrary curve parameters, which
-- is how the Pallas and Vesta point types (and any other curve) are supported. The
-- curve order must be prime.
class CurvePt a where

  -- | The curve generator (base point); a constant for each curve.
  base :: a

  -- | Decode a compressed point from a `ByteString`, giving @Nothing@ when the input
  -- is not a valid encoding.
  fromBytesC :: ByteString -> Maybe a

  -- | Check that a point satisfies the curve equation. The `Point` instance calls it
  -- while decoding (`fromBytesC`) and encoding (`toBytesC`); the hash-to-curve code is
  -- also described as calling it as a redundant safety check.
  isOnCurve :: a -> Bool

  -- | The additive inverse of a point.
  negatePt :: a -> a

  -- | The neutral element (point at infinity); a constant for each curve.
  neutral :: a

  -- | Add two points of the same elliptic curve.
  pointAdd :: a -> a -> a

  -- | Encode a point in compressed form as a @ByteString@.
  toBytesC :: a -> ByteString


instance (Field f, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  CurvePt (Point a b baseX baseY f) where

  -- The base point comes straight from the baseX/baseY type literals, with z = 1.
  base = Projective (fromInteger $ BASE_X) (fromInteger $ BASE_Y) 1


  -- Decode a compressed point per section 2.3.4 of https://www.secg.org/sec1-v2.pdf;
  -- uncompressed encodings are not supported.
  --   * the single byte 0x00 decodes to the neutral point;
  --   * 0x02 / 0x03 followed by a serialized x-coordinate decodes to the curve point
  --     with that x whose y-coordinate has sgn0 equal to 0 / 1 respectively;
  --   * anything else (bad length or prefix, undecodable x, or no square root of
  --     x^3 + a*x + b) yields Nothing.
  fromBytesC bytes
    | length bytes == 1 && index bytes 0 == 0 = Just neutral
    | length bytes == corLen && (prefix == 0x02 || prefix == 0x03) = result
    | otherwise = Nothing
    where
      -- Only demanded after the length test passed (corLen >= 1), so never indexes
      -- an empty string.
      prefix = index bytes 0
      aF = fromInteger (A) :: f
      bF = fromInteger (B) :: f
      -- Prefix byte plus one serialized field element; assumes toBytesF emits a
      -- fixed-length encoding (measured here on the curve constant a).
      corLen = 1 + length (toBytesF aF)
      -- Everything after the prefix byte is the x-coordinate.
      x = fromBytesF (drop 1 bytes) :: Maybe f
      -- The sgn0 the recovered y-coordinate must have (0 for prefix 0x02, else 1).
      wantSgn0 = if prefix == 0x02 then 0 else 1 :: Integer
      -- y^2 = x^3 + a*x + b; a failed square root propagates as Nothing.
      ySquared = (\t -> t ^ (3 :: Integer) + aF * t + bF) <$> x
      yRoot = ySquared >>= sqrt
      -- Select whichever root carries the requested sign (y is never zero).
      y = (\t -> if sgn0 t == wantSgn0 then t else negate t) <$> yRoot
      -- On the curve by construction; isOnCurve re-checks it as a safety net.
      result = mfilter isOnCurve (Projective <$> x <*> y <*> Just 1)
                 :: Maybe (Point a b baseX baseY f)


  -- Check the projective form of the curve equation, z*y^2 = x^3 + a*x*z^2 + b*z^3,
  -- which the neutral point (0, 1, 0) also satisfies.
  isOnCurve (Projective x y z) =
    z * y ^ (2 :: Integer)
      == x ^ (3 :: Integer) + aF * x * z ^ (2 :: Integer) + bF * z ^ (3 :: Integer)
    where
      aF = fromInteger (A) :: f
      bF = fromInteger (B) :: f


  -- Negation flips the sign of the y-coordinate.
  negatePt (Projective x y z) = Projective x (negate y) z


  -- Anything with z = 0 is neutral (y cannot be 0); this is the canonical one.
  neutral = Projective 0 1 0


  -- Complete projective addition: Algorithm 1 of https://eprint.iacr.org/2015/1060.pdf
  -- (page 8) for prime-order short Weierstrass curves y^2 = x^3 + a*x + b. The formula
  -- is complete, so doubling (pointAdd p p) and the neutral point need no special
  -- handling. Each repeated sub-expression is named once, and the constants a and 3*b
  -- are lifted into the field once per call rather than at every use.
  pointAdd (Projective x1 y1 z1) (Projective x2 y2 z2) =
    Projective (t3 * yMinus - t5 * t4b) (yMinus * yPlus + t1b * t4b) (t5 * yPlus + t3 * t1b)
    where
      aF = fromInteger (A) :: f
      b3F = fromInteger (3 * B) :: f
      t0 = x1 * x2
      t1 = y1 * y2
      t2 = z1 * z2
      t3 = (x1 + y1) * (x2 + y2) - t0 - t1   -- x1*y2 + x2*y1
      t4 = (x1 + z1) * (x2 + z2) - t0 - t2   -- x1*z2 + x2*z1
      t5 = (y1 + z1) * (y2 + z2) - t1 - t2   -- y1*z2 + y2*z1
      at2 = aF * t2
      w = aF * t4 + b3F * t2
      yMinus = t1 - w
      yPlus = t1 + w
      t1b = 3 * t0 + at2
      t4b = b3F * t4 + aF * (t0 - at2)


  -- Encode a point in compressed form per section 2.3.3 of
  -- https://www.secg.org/sec1-v2.pdf: the byte 0x00 for the neutral point, otherwise
  -- 0x02 (sgn0 of the affine y-coordinate is 0) or 0x03 (it is not) followed by the
  -- serialized affine x-coordinate. Calls error for a point that is not on the curve.
  toBytesC pt
    | not (isOnCurve pt) = error "trying to serialize point not on curve"
    | _z pt == 0 = pack [0]
    | sgn0 y == 0 = cons 0x02 (toBytesF x)
    | otherwise = cons 0x03 (toBytesF x)
    where
      -- Convert to affine coordinates; a single (costly) inversion serves both.
      zInv = inv0 (_z pt)
      x = _x pt * zInv
      y = _y pt * zInv


-- | Scalar multiplication and the simplified SWU map-to-curve for curve points. The
-- class relates a `CurvePt` point type @a@ to a `Field` type @b@ whose elements act as
-- scalars for `pointMul` and as inputs for `mapToCurveSimpleSwu`, which is used in the
-- later stages of hashing to the curve. It covers both the Pallas and Vesta point types.
class (CurvePt a, Field b) => Curve a b where

  -- | Multiply a prime-order curve point by a field-element scalar (scalar first).
  -- For example, a `PastaCurves.Fq` scalar with a `PastaCurves.Pallas` point, whose
  -- coordinates happen to be `PastaCurves.Fp` elements.
  pointMul :: b -> a -> a

  -- | A straightforward implementation of the Simplified Shallue-van de Woestijne-Ulas
  -- method, mapping a field element to a curve point; see
  -- https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-14.html#name-simplified-shallue-van-de-w
  -- Both curve coefficients must be non-zero, and the constant @Z@ must meet the
  -- requirements stated there.
  mapToCurveSimpleSwu
    :: b  -- ^ the field element @u@ to map
    -> b  -- ^ the constant @Z@
    -> a


instance (Field f1, Field f2, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  Curve (Point a b baseX baseY f1) f2 where


  -- Right-to-left double-and-add: walk the scalar from its low end, using sgn0 as the
  -- low-bit test and shiftR1 to advance. Each step doubles the running multiple of the
  -- point and, when the bit is set, adds it into the accumulator. Doubling currently
  -- reuses pointAdd; a dedicated doubling routine may replace it in the future.
  pointMul s pt = go s pt neutral
    where
      go :: f2 -> Point a b baseX baseY f1 -> Point a b baseX baseY f1
         -> Point a b baseX baseY f1
      go k q acc
        -- a field element cannot drop below zero, so reaching zero ends the walk
        | k == 0 = acc
        | otherwise = go (shiftR1 k) (pointAdd q q) acc2
        where
          acc2 = if sgn0 k == 0 then acc else pointAdd acc q


  -- Straight-line Simplified SWU following
  -- https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-14.html#section-6.6.2-7
  -- The constant z is supplied by the caller (its Pasta-specific value is computed
  -- elsewhere). Both inputs are moved from f2 into f1 through their Integer values.
  mapToCurveSimpleSwu fu fz
    | A * B /= 0 = Projective x y 1
    | otherwise = error "Curve params A*B must not be zero"
    where
      aF = fromInteger (A) :: f1
      bF = fromInteger (B) :: f1
      u = fromInteger (toI fu) :: f1
      z = fromInteger (toI fz) :: f1
      -- right-hand side of the curve equation
      g t = t ^ (3 :: Integer) + aF * t + bF
      zu2 = z * u ^ (2 :: Integer)
      -- tv1 = inv0 (z^2 * u^4 + z * u^2)
      tv1 = inv0 (zu2 * zu2 + zu2)
      negBOverA = fromInteger ((-1) * B) * inv0 aF
      -- exceptional case tv1 == 0 uses x1 = b / (z * a)
      x1
        | toI tv1 == 0 = bF * inv0 (z * aF)
        | otherwise = negBOverA * (1 + tv1)
      gx1 = g x1
      x2 = zu2 * x1
      gx2 = g x2
      -- prefer x1 when g(x1) is a square, otherwise fall back to x2
      root t = fromJust (sqrt t)
      (x, ya)
        | isSqr gx1 = (x1, root gx1)
        | otherwise = (x2, root gx2)
      -- give y the same sgn0 as u
      y = if sgn0 u /= sgn0 ya then negate ya else ya