-- |
-- Module      : Crypto.Store.PKCS8.EC
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Additional EC utilities.
module Crypto.Store.PKCS8.EC
    ( numBytes
    , curveSizeBytes
    , curveOrderBytes
    , curveNameOID
    , getSerializedPoint
    , module Data.X509.EC
    ) where

import           Data.ASN1.OID
import qualified Data.ByteString as B
import           Data.Maybe (fromMaybe)

import Data.X509
import Data.X509.EC

import Crypto.Number.Basic (numBits, numBytes)
import Crypto.Number.Serialize (i2ospOf_)
import Crypto.PubKey.ECC.Prim
import Crypto.PubKey.ECC.Types

import Crypto.Store.CMS.Util

-- | Number of bytes necessary to serialize @n@ bits, i.e. @n / 8@ rounded
-- up to the next whole byte.
bitsToBytes :: Int -> Int
bitsToBytes n = (n + 7) `div` 8

-- | Number of bytes to serialize a field element: the curve size in bits
-- ('curveSizeBits') rounded up to a whole number of bytes.
curveSizeBytes :: Curve -> Int
curveSizeBytes = bitsToBytes . curveSizeBits

-- | Number of bytes to serialize a scalar: the bit length of the curve
-- parameter 'ecc_n' (from the curve's common parameters) rounded up to a
-- whole number of bytes.
curveOrderBytes :: Curve -> Int
curveOrderBytes curve = bitsToBytes (numBits order)
  where
    order = ecc_n (common_curve curve)

-- | Transform a private scalar to a point in uncompressed format.
--
-- The point is obtained with 'pointBaseMul' and encoded as follows:
--
-- * an affine point @(x, y)@ becomes the byte @0x04@ followed by @x@ and
--   @y@, each written big-endian on exactly 'curveSizeBytes' bytes;
--
-- * the point at infinity ('PointO') becomes the single byte @0x00@.
getSerializedPoint :: Curve -> PrivateNumber -> SerializedPoint
getSerializedPoint curve d = SerializedPoint (serializePoint pt)
  where
    pt :: Point
    pt = pointBaseMul curve d

    -- Fixed-width encoding of one coordinate; the width is the same for
    -- x and y so the output length depends only on the curve.
    bs :: Integer -> B.ByteString
    bs = i2ospOf_ (curveSizeBytes curve)

    serializePoint :: Point -> B.ByteString
    serializePoint PointO      = B.singleton 0
    serializePoint (Point x y) = B.cons 4 (B.append (bs x) (bs y))

-- | Return the OID associated to a curve name.
--
-- The name is looked up in 'curvesOIDTable'.  This function is partial: it
-- calls 'error' when the curve has no entry in that table.
curveNameOID :: CurveName -> OID
curveNameOID name =
    fromMaybe (error $ "PKCS8: OID unknown for EC curve " ++ show name)
              (lookupOID curvesOIDTable name)

-- | OIDs of the named curves, as used by 'curveNameOID'.
--
-- All entries use the arc @1.3.132.0@, except 'SEC_p192r1' and 'SEC_p256r1',
-- which use @1.2.840.10045.3.1@.
curvesOIDTable :: OIDTable CurveName
curvesOIDTable =
    [ (SEC_p112r1, secg 6)
    , (SEC_p112r2, secg 7)
    , (SEC_p128r1, secg 28)
    , (SEC_p128r2, secg 29)
    , (SEC_p160k1, secg 9)
    , (SEC_p160r1, secg 8)
    , (SEC_p160r2, secg 30)
    , (SEC_p192k1, secg 31)
    , (SEC_p192r1, ansiPrime 1)
    , (SEC_p224k1, secg 32)
    , (SEC_p224r1, secg 33)
    , (SEC_p256k1, secg 10)
    , (SEC_p256r1, ansiPrime 7)
    , (SEC_p384r1, secg 34)
    , (SEC_p521r1, secg 35)
    , (SEC_t113r1, secg 4)
    , (SEC_t113r2, secg 5)
    , (SEC_t131r1, secg 22)
    , (SEC_t131r2, secg 23)
    , (SEC_t163k1, secg 1)
    , (SEC_t163r1, secg 2)
    , (SEC_t163r2, secg 15)
    , (SEC_t193r1, secg 24)
    , (SEC_t193r2, secg 25)
    , (SEC_t233k1, secg 26)
    , (SEC_t233r1, secg 27)
    , (SEC_t239k1, secg 3)
    , (SEC_t283k1, secg 16)
    , (SEC_t283r1, secg 17)
    , (SEC_t409k1, secg 36)
    , (SEC_t409r1, secg 37)
    , (SEC_t571k1, secg 38)
    , (SEC_t571r1, secg 39)
    ]
  where
    -- Curves registered under 1.3.132.0 (the SEC 2 curve arc).
    secg :: Integer -> OID
    secg n = [1, 3, 132, 0, n]

    -- Prime curves registered under 1.2.840.10045.3.1 (ANSI X9.62).
    ansiPrime :: Integer -> OID
    ansiPrime n = [1, 2, 840, 10045, 3, 1, n]