{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}
{-|
Module: Crypto.Spake2.Groups.IntegerGroup
Description: Multiplicative group of integers modulo \(n\)
-}
module Crypto.Spake2.Groups.IntegerGroup
  ( IntegerGroup(..)
  , makeIntegerGroup
  , i1024
  ) where

import Protolude hiding (group, length)

import Crypto.Error (CryptoFailable(..), CryptoError(..))
import Crypto.Number.Basic (numBits)
import Crypto.Number.Generate (generateMax)
import Crypto.Number.ModArithmetic (expSafe)

import Crypto.Spake2.Group
  ( AbelianGroup(..)
  , Group(..)
  , KeyPair(..)
  , elementSizeBytes
  )
import Crypto.Spake2.Util
  ( expandArbitraryElementSeed
  , bytesToNumber
  , unsafeNumberToBytes
  )

-- | The multiplicative group of integers modulo a modulus \(p\), restricted
-- to a subgroup of order \(q\) generated by \(g\).
--
-- Note that the 'order' field holds the /modulus/ \(p\): every element
-- operation reduces modulo it. The full multiplicative group has
-- \(p - 1\) elements.
--
-- Prefer 'makeIntegerGroup', which validates the parameters, over the raw
-- constructor.
data IntegerGroup = IntegerGroup
  { order :: !Integer
    -- ^ The modulus \(p\). Elements are integers in \((0, p)\).
  , subgroupOrder :: !Integer
    -- ^ The order \(q\) of the working subgroup. Scalars are reduced
    -- modulo \(q\).
  , generator :: !Integer
    -- ^ The generator \(g\) that public keys are derived from.
  }
  deriving (Eq, Show)

-- | Construct an 'IntegerGroup' from a modulus \(p\), a subgroup order \(q\)
-- and a generator \(g\), validating the parameters first.
--
-- Returns 'Nothing' unless all of the following hold:
--
-- * \(p > 2\), so arithmetic modulo \(p\) is non-degenerate.
--
-- * \(q > 1\) and \(q\) divides \(p - 1\). 'arbitraryElement' relies on
--   this when it raises to the cofactor \((p - 1) / q\). A zero \(q\) would
--   also cause a division by zero when reducing scalars.
--
-- * \(1 < g < p\). In particular \(g = 1\) is rejected: the identity only
--   generates the trivial subgroup, so every public key would be @1@.
--
-- * \(g^q \equiv 1 \pmod{p}\), i.e. \(g\) lies in the order-\(q\) subgroup.
--   'decodeElement' applies the same membership test to incoming elements,
--   so a generator outside the subgroup would yield public keys that peers
--   reject.
--
-- These checks do not establish that \(p\) or \(q\) is prime.
makeIntegerGroup
  :: Integer -- ^ modulus \(p\) (stored in the 'order' field)
  -> Integer -- ^ subgroup order \(q\)
  -> Integer -- ^ generator \(g\)
  -> Maybe IntegerGroup
makeIntegerGroup p q g
  | p <= 2 = Nothing
  | q <= 1 = Nothing
  | (p - 1) `mod` q /= 0 = Nothing
  | g <= 1 || g >= p = Nothing
  -- Earlier guards ensure p > 2 here, so the exponentiation is well-defined.
  | expSafe g q p /= 1 = Nothing
  | otherwise = Just (IntegerGroup p q g)


instance Group IntegerGroup where
  type Element IntegerGroup = Integer

  -- The group operation is multiplication, reduced modulo the 'order'
  -- field (the modulus p).
  elementAdd grp a b = (a * b) `mod` order grp

  -- Inverse by exponentiation. For any x in the order-q subgroup, x^q = 1,
  -- so x^(q-1) = x^(-1) (mod p). This is only an inverse for elements of
  -- that subgroup (e.g. values accepted by 'decodeElement').
  elementNegate grp a = expSafe a (subgroupOrder grp - 1) (order grp)

  -- The multiplicative identity.
  groupIdentity _ = 1

  -- Serialise to 'elementSizeBytes' bytes via 'unsafeNumberToBytes'.
  encodeElement grp = unsafeNumberToBytes (elementSizeBytes grp)

  -- Decode and validate. The value must lie strictly between 0 and p and
  -- must satisfy x^q = 1 (mod p), i.e. belong to the order-q subgroup.
  decodeElement grp bytes
    | candidate <= 0 || candidate >= modulus =
        CryptoFailed CryptoError_PointSizeInvalid
    | expSafe candidate (subgroupOrder grp) modulus /= groupIdentity grp =
        CryptoFailed CryptoError_PointCoordinatesInvalid
    | otherwise =
        CryptoPassed candidate
    where
      candidate = bytesToNumber bytes
      modulus = order grp

  -- Elements are bounded by the modulus, so its bit length is the size.
  elementSizeBits = numBits . order

  -- Deterministically derive an element from a seed:
  --   1. Expand the seed to 'elementSizeBytes' bytes.
  --   2. Reduce it mod p to get h.
  --   3. Raise h to the cofactor (p - 1) / q. Assuming p is prime and q
  --      divides p - 1, this lands in the order-q subgroup.
  -- If h happens to be 0 the result is 0, which 'decodeElement' would
  -- reject; that case is not handled here.
  arbitraryElement grp seed = expSafe h cofactor p
    where
      expanded = expandArbitraryElementSeed seed (elementSizeBytes grp) :: ByteString
      p = order grp
      cofactor = (p - 1) `div` subgroupOrder grp
      h = bytesToNumber expanded `mod` p


instance AbelianGroup IntegerGroup where
  type Scalar IntegerGroup = Integer

  -- In a multiplicative group, "scalar multiplication" is exponentiation:
  -- x^(n mod q) mod p. Reducing n modulo q first keeps the exponent below q.
  -- Haskell's `mod` with a positive divisor also maps a negative n into
  -- [0, q).
  scalarMultiply grp n x = expSafe x e (order grp)
    where
      e = n `mod` subgroupOrder grp

  -- Scalars are integers reduced modulo the subgroup order q.
  integerToScalar grp = (`mod` subgroupOrder grp)

  -- Scalars already are plain integers.
  scalarToInteger _ s = s

  -- Draw a secret scalar with 'generateMax' q (per cryptonite, in [0, q);
  -- 0 is not excluded). Pair it with the public element generator^scalar.
  generateElement grp = mkPair <$> generateMax (subgroupOrder grp)
    where
      mkPair s = KeyPair (scalarMultiply grp s (generator grp)) s

  -- Scalars are bounded by q, so its bit length is the scalar size.
  scalarSizeBits = numBits . subgroupOrder


-- | A 1024-bit integer group: a 1024-bit modulus together with a 160-bit
-- subgroup order and a generator.
--
-- Taken from the J-PAKE demo at http://haofeng66.googlepages.com/JPAKEDemo.java
-- (a historical URL that may no longer resolve), via
-- [python-spake2](https://github.com/warner/python-spake2).
i1024 :: IntegerGroup
i1024 = IntegerGroup p q g
  where
    -- Modulus (stored in the 'order' field): 256 hex digits, i.e. 1024 bits.
    p :: Integer
    p = 0xE0A67598CD1B763BC98C8ABB333E5DDA0CD3AA0E5E1FB5BA8A7B4EABC10BA338FAE06DD4B90FDA70D7CF0CB0C638BE3341BEC0AF8A7330A3307DED2299A0EE606DF035177A239C34A912C202AA5F83B9C4A7CF0235B5316BFC6EFB9A248411258B30B839AF172440F32563056CB67A861158DDD90E6A894C72A5BBEF9E286C6B
    -- Subgroup order: 40 hex digits, i.e. 160 bits.
    q :: Integer
    q = 0xE950511EAB424B9A19A2AEB4E159B7844C589C4F
    -- Generator of the subgroup.
    g :: Integer
    g = 0xD29D5121B0423C2769AB21843E5A3240FF19CACC792264E3BB6BE4F78EDD1B15C4DFF7F1D905431F0AB16790E1F773B5CE01C804E509066A9919F5195F4ABC58189FD9FF987389CB5BEDF21B4DAB4F8B76A055FFE2770988FE2EC2DE11AD92219F0B351869AC24DA3D7BA87011A701CE8EE7BFE49486ED4527B7186CA4610A75