module Crypto.Paseto.Keys.V3.Internal
  ( isScalarValid
  , encodeScalar
  , ScalarDecodingError (..)
  , renderScalarDecodingError
  , decodeScalar

  , encodePointUncompressed
  , encodePointCompressed
  , UncompressedPointDecodingError (..)
  , renderUncompressedPointDecodingError
  , decodePointUncompressed
  , CompressedPointDecodingError (..)
  , renderCompressedPointDecodingError
  , decodePointCompressed
  , fromPrivateKey
  ) where

import Control.Monad ( when )
import qualified Crypto.Number.Basic as Crypto.Number
import qualified Crypto.Number.ModArithmetic as Crypto.Number
import qualified Crypto.Number.Serialize as Crypto.Number
import qualified Crypto.PubKey.ECC.ECDSA as ECC.ECDSA
import qualified Crypto.PubKey.ECC.Prim as ECC
import qualified Crypto.PubKey.ECC.Types as ECC
import Data.ByteArray ( ScrubbedBytes )
import qualified Data.ByteArray as BA
import Data.ByteString ( ByteString )
import qualified Data.ByteString as BS
import Data.Text ( Text )
import qualified Data.Text as T
import Data.Word ( Word8 )
import Prelude

-- | Bit length of the curve order @n@, rounded up to a whole number of bytes.
curveOrderBytes :: ECC.Curve -> Int
curveOrderBytes curve = (orderBits + 7) `div` 8
  where
    -- Number of bits of the order of the curve's base point.
    orderBits :: Int
    orderBits = Crypto.Number.numBits (ECC.ecc_n (ECC.common_curve curve))

-- | Check that a scalar lies strictly between zero and the order @n@ of the
-- given curve, i.e. @0 < s < n@.
isScalarValid :: ECC.Curve -> Integer -> Bool
isScalarValid curve s = 0 < s && s < order
  where
    order :: Integer
    order = ECC.ecc_n (ECC.common_curve curve)

-- | Serialise an elliptic curve scalar value into a byte string that is
-- exactly 'curveOrderBytes' bytes long for the given curve.
encodeScalar :: ECC.Curve -> Integer -> ScrubbedBytes
encodeScalar curve scalar = Crypto.Number.i2ospOf_ width scalar
  where
    width :: Int
    width = curveOrderBytes curve

-- | Reasons why decoding a scalar value can fail.
data ScalarDecodingError
  = -- | The input does not have the expected number of bytes.
    ScalarDecodingInvalidLengthError
      -- | Length that was expected
      !Int
      -- | Length that was actually provided
      !Int
  | -- | The decoded scalar is not valid for the curve.
    ScalarDecodingInvalidError
  deriving stock (Eq, Show)

-- | Produce a human-readable 'Text' description of a 'ScalarDecodingError'.
renderScalarDecodingError :: ScalarDecodingError -> Text
renderScalarDecodingError (ScalarDecodingInvalidLengthError expected actual) =
  T.concat
    [ "Decoded scalar value is of length "
    , T.pack (show actual)
    , ", but was expected to be "
    , T.pack (show expected)
    , "."
    ]
renderScalarDecodingError ScalarDecodingInvalidError =
  "Decoded scalar value is invalid for the curve."

-- | Decode an elliptic curve scalar value.
--
-- Fails with 'ScalarDecodingInvalidLengthError' when the input is not exactly
-- 'curveOrderBytes' bytes long, and with 'ScalarDecodingInvalidError' when
-- the decoded integer is rejected by 'isScalarValid'.
decodeScalar :: ECC.Curve -> ScrubbedBytes -> Either ScalarDecodingError Integer
decodeScalar curve bs = do
  -- Reject inputs of the wrong size before interpreting the bytes.
  when (actualLen /= expectedLen) $
    Left (ScalarDecodingInvalidLengthError expectedLen actualLen)
  if isScalarValid curve scalar
    then Right scalar
    else Left ScalarDecodingInvalidError
  where
    expectedLen :: Int
    expectedLen = curveOrderBytes curve

    actualLen :: Int
    actualLen = BA.length bs

    scalar :: Integer
    scalar = Crypto.Number.os2ip bs

-- | Encode an elliptic curve point into its uncompressed binary format as
-- defined by [SEC 1](https://www.secg.org/sec1-v2.pdf) and
-- [RFC 5480 section 2.2](https://datatracker.ietf.org/doc/html/rfc5480#section-2.2).
--
-- The result is the prefix byte @0x04@ followed by the fixed-width encodings
-- of the @x@ and @y@ coordinates, in that order.
--
-- Note that this function will only accept a point on an elliptic curve over
-- 𝔽p (i.e. 'ECC.CurvePrime').
--
-- This function is partial: it calls 'error' when the point is not valid for
-- the curve or when it is the point at infinity.
encodePointUncompressed :: ECC.CurvePrime -> ECC.Point -> ByteString
encodePointUncompressed curvePrime point
  | not (ECC.isPointValid curve point) =
      error "encodePointUncompressed: point is invalid"
  | otherwise =
      case point of
        ECC.PointO ->
          error "encodePointUncompressed: cannot encode point at infinity"
        ECC.Point x y ->
          BS.concat
            [ BS.singleton 0x04
            , Crypto.Number.i2ospOf_ coordinateLen x
            , Crypto.Number.i2ospOf_ coordinateLen y
            ]
  where
    curve :: ECC.Curve
    curve = ECC.CurveFP curvePrime

    -- Width of a single coordinate in bytes. The bit size is rounded up so
    -- that curves whose size is not a multiple of 8 bits still get enough
    -- room for every coordinate.
    coordinateLen :: Int
    coordinateLen = (ECC.curveSizeBits curve + 7) `div` 8

-- | Encode an elliptic curve point into its compressed binary format as
-- defined by [SEC 1](https://www.secg.org/sec1-v2.pdf).
--
-- The result is a prefix byte (@0x03@ when @y@ is odd, @0x02@ otherwise)
-- followed by the fixed-width encoding of the @x@ coordinate.
--
-- Note that this function will only accept a point on an elliptic curve over
-- 𝔽p (i.e. 'ECC.CurvePrime').
--
-- This function is partial: it calls 'error' when the point is not valid for
-- the curve or when it is the point at infinity.
--
-- Adapted from
-- [cryptonite issue #302](https://github.com/haskell-crypto/cryptonite/issues/302#issue-531003322).
encodePointCompressed :: ECC.CurvePrime -> ECC.Point -> ByteString
encodePointCompressed curvePrime point
  | not (ECC.isPointValid curve point) =
      error "encodePointCompressed: point is invalid"
  | otherwise =
      case point of
        ECC.PointO ->
          error "encodePointCompressed: cannot encode point at infinity"
        -- `i2ospOf_` is used because the coordinate of a valid point is
        -- expected to fit into `coordinateLen` bytes.
        ECC.Point x y ->
          prefix y <> Crypto.Number.i2ospOf_ coordinateLen x
  where
    -- Prefix byte recording the parity of the @y@ coordinate.
    prefix :: Integer -> ByteString
    prefix y
      | odd y = BS.singleton 0x03
      | otherwise = BS.singleton 0x02

    curve :: ECC.Curve
    curve = ECC.CurveFP curvePrime

    -- Width of the @x@ coordinate in bytes. The bit size is rounded up so
    -- that curves whose size is not a multiple of 8 bits still get enough
    -- room for the coordinate.
    coordinateLen :: Int
    coordinateLen = (ECC.curveSizeBits curve + 7) `div` 8

-- | Reasons why decoding an uncompressed elliptic curve point can fail.
data UncompressedPointDecodingError
  = -- | The first byte is not the uncompressed-point prefix (@0x04@).
    UncompressedPointDecodingInvalidPrefixError
      -- | Prefix byte that was found instead.
      !Word8
  | -- | The input does not have the expected number of bytes.
    UncompressedPointDecodingInvalidLengthError
      -- | Length that was expected
      !Int
      -- | Length that was actually provided
      !Int
  | -- | The decoded point is not valid for the curve.
    UncompressedPointDecodingInvalidPointError !ECC.Point
  deriving stock (Eq, Show)

-- | Produce a human-readable 'Text' description of an
-- 'UncompressedPointDecodingError'.
renderUncompressedPointDecodingError :: UncompressedPointDecodingError -> Text
renderUncompressedPointDecodingError (UncompressedPointDecodingInvalidPrefixError invalidPrefix) =
  T.concat
    [ "Expected prefix "
    , T.pack (show (0x04 :: Word8))
    , " for uncompressed point, but encountered "
    , T.pack (show invalidPrefix)
    , "."
    ]
renderUncompressedPointDecodingError (UncompressedPointDecodingInvalidLengthError expected actual) =
  T.concat
    [ "Decoded point length is expected to be "
    , T.pack (show expected)
    , ", but it was "
    , T.pack (show actual)
    , "."
    ]
renderUncompressedPointDecodingError (UncompressedPointDecodingInvalidPointError _) =
  "Decoded point is invalid for the curve."

-- | Decode an elliptic curve point from its uncompressed binary format as
-- defined by [SEC 1](https://www.secg.org/sec1-v2.pdf) and
-- [RFC 5480 section 2.2](https://datatracker.ietf.org/doc/html/rfc5480#section-2.2).
--
-- The input must be the prefix byte @0x04@ followed by the fixed-width
-- encodings of the @x@ and @y@ coordinates, in that order.
--
-- Failure cases, in the order in which they are checked:
--
-- * 'UncompressedPointDecodingInvalidLengthError' when the input does not
--   have the expected length.
-- * 'UncompressedPointDecodingInvalidPrefixError' when the first byte is not
--   @0x04@.
-- * 'UncompressedPointDecodingInvalidPointError' when the decoded
--   coordinates do not form a valid point for the curve.
--
-- Note that this function will only decode a point on an elliptic curve over
-- 𝔽p (i.e. 'ECC.CurvePrime').
decodePointUncompressed :: ECC.CurvePrime -> ByteString -> Either UncompressedPointDecodingError ECC.Point
decodePointUncompressed curvePrime bs = do
  when (actualPointLen /= expectedPointLen) $
    Left (UncompressedPointDecodingInvalidLengthError expectedPointLen actualPointLen)

  case BS.uncons bs of
    -- Only reachable for an empty input, which the length check above is
    -- expected to reject already.
    Nothing ->
      Left (UncompressedPointDecodingInvalidLengthError expectedPointLen 0)
    Just (prefix, rest)
      | prefix /= 0x04 ->
          Left (UncompressedPointDecodingInvalidPrefixError prefix)
      | otherwise ->
          -- Split the payload in the middle: the first half is @x@ and the
          -- second half is @y@. Splitting at the length of the whole input
          -- would put both coordinates into @x@ and leave @y@ empty.
          let (xBs, yBs) = BS.splitAt coordinateLen rest
              point = ECC.Point (Crypto.Number.os2ip xBs) (Crypto.Number.os2ip yBs)
          in if ECC.isPointValid curve point
               then Right point
               else Left (UncompressedPointDecodingInvalidPointError point)
  where
    curve :: ECC.Curve
    curve = ECC.CurveFP curvePrime

    -- Width of a single coordinate in bytes (bit size rounded up).
    coordinateLen :: Int
    coordinateLen = (ECC.curveSizeBits curve + 7) `div` 8

    -- One prefix byte plus two coordinates.
    expectedPointLen :: Int
    expectedPointLen = 1 + (coordinateLen * 2)

    actualPointLen :: Int
    actualPointLen = BS.length bs

-- | Reasons why decoding a compressed elliptic curve point can fail.
data CompressedPointDecodingError
  = -- | The first byte is neither of the compressed-point prefixes
    -- (@0x02@ or @0x03@).
    CompressedPointDecodingInvalidPrefixError
      -- | Prefix byte that was found instead.
      !Word8
  | -- | The input does not have the expected number of bytes.
    CompressedPointDecodingInvalidLengthError
      -- | Length that was expected
      !Int
      -- | Length that was actually provided
      !Int
  | -- | No modular square root could be found for a value.
    CompressedPointDecodingModularSquareRootError
  | -- | The decoded point is not valid for the curve.
    CompressedPointDecodingInvalidPointError !ECC.Point
  deriving stock (Eq, Show)

-- | Render a 'CompressedPointDecodingError' as 'Text'.
renderCompressedPointDecodingError :: CompressedPointDecodingError -> Text
renderCompressedPointDecodingError (CompressedPointDecodingInvalidPrefixError invalidPrefix) =
  T.concat
    [ "Expected prefix of either "
    , T.pack (show (0x02 :: Word8))
    , " or "
    , T.pack (show (0x03 :: Word8))
    , " for compressed point, but encountered "
    , T.pack (show invalidPrefix)
    , "."
    ]
renderCompressedPointDecodingError (CompressedPointDecodingInvalidLengthError expected actual) =
  T.concat
    [ "Decoded point length is expected to be "
    , T.pack (show expected)
    , ", but it was "
    , T.pack (show actual)
    , "."
    ]
renderCompressedPointDecodingError CompressedPointDecodingModularSquareRootError =
  -- The x-coordinate is carried by the compressed encoding itself; the
  -- modular square root is what yields the y-coordinate.
  "Failed to recover the y-coordinate from the compressed point."
renderCompressedPointDecodingError (CompressedPointDecodingInvalidPointError _) =
  "Decoded point is invalid for the curve."

-- | Parity of the @y@ coordinate as announced by the prefix byte of a
-- compressed point.
data EvenOrOddY = EvenY | OddY

-- | Interpret a prefix byte: @0x02@ maps to 'EvenY', @0x03@ maps to 'OddY',
-- and any other byte yields 'Nothing'.
toEvenOrOddY :: Word8 -> Maybe EvenOrOddY
toEvenOrOddY prefix =
  case prefix of
    0x02 -> Just EvenY
    0x03 -> Just OddY
    _ -> Nothing

-- | Decode an elliptic curve point from its compressed binary format as
-- defined by [SEC 1](https://www.secg.org/sec1-v2.pdf) and
-- [RFC 5480 section 2.2](https://datatracker.ietf.org/doc/html/rfc5480#section-2.2).
--
-- The input must be a prefix byte (@0x02@ for an even @y@, @0x03@ for an odd
-- @y@) followed by the fixed-width encoding of the @x@ coordinate. The @y@
-- coordinate is recovered as a square root of @x^3 + a*x + b@ modulo @p@,
-- where @a@, @b@ and @p@ are taken from the given curve, choosing the root
-- whose parity matches the prefix.
--
-- Failure cases, in the order in which they are checked:
--
-- * 'CompressedPointDecodingInvalidLengthError' when the input does not have
--   the expected length.
-- * 'CompressedPointDecodingInvalidPrefixError' when the first byte is
--   neither @0x02@ nor @0x03@.
-- * 'CompressedPointDecodingModularSquareRootError' when no square root
--   could be found.
-- * 'CompressedPointDecodingInvalidPointError' when the recovered point is
--   not valid for the curve.
--
-- Note that this function will only decode a point on an elliptic curve over
-- 𝔽p (i.e. 'ECC.CurvePrime').
--
-- Thanks to
-- [cryptonite PR #303](https://github.com/haskell-crypto/cryptonite/pull/303),
-- there's a function that we can use to compute a square root modulo a prime
-- number ('Crypto.Number.squareRoot').
decodePointCompressed :: ECC.CurvePrime -> ByteString -> Either CompressedPointDecodingError ECC.Point
decodePointCompressed curvePrime@(ECC.CurvePrime p curveCommon) bs = do
  when (actualLen /= expectedLen) $
    Left (CompressedPointDecodingInvalidLengthError expectedLen actualLen)

  case BS.uncons bs of
    -- Only reachable for an empty input, which the length check above is
    -- expected to reject already.
    Nothing ->
      Left (CompressedPointDecodingInvalidLengthError expectedLen 0)
    Just (prefix, xBs) -> do
      parity <-
        maybe
          (Left (CompressedPointDecodingInvalidPrefixError prefix))
          Right
          (toEvenOrOddY prefix)

      let x :: Integer
          x = Crypto.Number.os2ip xBs

          -- Right-hand side of the curve equation y^2 = x^3 + a*x + b,
          -- using the curve's own coefficients rather than assuming a = -3.
          ySquared :: Integer
          ySquared =
            ((x ^ (3 :: Integer)) + (ECC.ecc_a curveCommon * x) + ECC.ecc_b curveCommon)
              `mod` p

      root <-
        maybe
          (Left CompressedPointDecodingModularSquareRootError)
          Right
          (Crypto.Number.squareRoot p ySquared)

      -- Both @root@ and @p - root@ square to the same value; keep the one
      -- whose parity matches what the prefix asked for.
      let y :: Integer
          y
            | odd root == wantsOddY parity = root
            | otherwise = p - root

          point :: ECC.Point
          point = ECC.Point x y

      if ECC.isPointValid curve point
        then Right point
        else Left (CompressedPointDecodingInvalidPointError point)
  where
    curve :: ECC.Curve
    curve = ECC.CurveFP curvePrime

    -- One prefix byte plus one coordinate (bit size rounded up to bytes).
    expectedLen :: Int
    expectedLen = 1 + ((ECC.curveSizeBits curve + 7) `div` 8)

    actualLen :: Int
    actualLen = BS.length bs

    wantsOddY :: EvenOrOddY -> Bool
    wantsOddY OddY = True
    wantsOddY EvenY = False

-- | Derive the 'ECC.ECDSA.PublicKey' belonging to an
-- 'ECC.ECDSA.PrivateKey' by multiplying the curve's base point by the
-- private scalar.
fromPrivateKey :: ECC.ECDSA.PrivateKey -> ECC.ECDSA.PublicKey
fromPrivateKey (ECC.ECDSA.PrivateKey curve d) =
  let publicPoint :: ECC.Point
      publicPoint = ECC.pointBaseMul curve d
  in ECC.ECDSA.PublicKey curve publicPoint