{-# OPTIONS_GHC -fno-warn-orphans #-}
module Voting.Protocol.Arithmetic
 ( module Voting.Protocol.Arithmetic
 , Natural
 , Random.RandomGen
 ) where

import Control.Arrow (first)
import Control.DeepSeq (NFData)
import Control.Monad (Monad(..))
import Data.Bits
import Data.Bool
import Data.Eq (Eq(..))
import Data.Foldable (Foldable, foldl')
import Data.Function (($), (.))
import Data.Functor ((<$>))
import Data.Int (Int)
import Data.Maybe (Maybe(..))
import Data.Ord (Ord(..))
import Data.Semigroup (Semigroup(..))
import Data.String (String, IsString(..))
import Numeric.Natural (Natural)
import Prelude (Integer, Integral(..), fromIntegral, Enum(..))
import Text.Show (Show(..))
import qualified Control.Monad.Trans.State.Strict as S
import qualified Crypto.Hash as Crypto
import qualified Data.ByteArray as ByteArray
import qualified Data.ByteString as BS
import qualified Data.List as List
import qualified Prelude as Num
import qualified System.Random as Random

-- * Type 'F'
-- | The type of the elements of a 'PrimeField'.
--
-- A field must satisfy the following properties:
--
-- * @(f, ('+'), 'zero')@ forms an abelian group,
--   called the 'Additive' group of 'f'.
--
-- * the non-zero elements of @f@, with ('*') and 'one', form an abelian group,
--   called the 'Multiplicative' group of 'f'.
--
-- * ('*') and ('+') are both associative:
--   @(a'*'b)'*'c == a'*'(b'*'c)@ and
--   @(a'+'b)'+'c == a'+'(b'+'c)@.
--
-- * ('*') and ('+') are both commutative:
--   @a'*'b == b'*'a@ and
--   @a'+'b == b'+'a@
--
-- * ('*') and ('+') are both left and right distributive:
--   @a'*'(b'+'c) == (a'*'b) '+' (a'*'c)@ and
--   @(a'+'b)'*'c == (a'*'c) '+' (b'*'c)@
--
-- The 'Natural' is always within @[0..'fieldCharac'-1]@.
newtype F p = F
 { unF :: Natural
   -- ^ The representative of the element,
   -- expected to lie in @[0..'fieldCharac'-1]@.
 }
 deriving (Eq, Ord, Show, NFData)

instance PrimeField p => FromNatural (F p) where
        -- Reduce modulo the characteristic so that the representative
        -- lies in @[0..'fieldCharac'-1]@.
        -- NOTE: a 'Natural' remainder can never be negative,
        -- hence no correction step is needed after 'mod'.
        fromNatural i = F (i `mod` fieldCharac @p)
instance ToNatural (F p) where
        nat = unF

instance PrimeField p => Additive (F p) where
        zero = F 0
        -- Add the representatives, then reduce back into the field.
        F a + F b = F (reduce (a + b))
                where reduce n = n `mod` fieldCharac @p
instance PrimeField p => Negable (F p) where
        -- 'zero' is its own opposite (@'fieldCharac' - 0@ would be out of range);
        -- any other element x has @'fieldCharac' - x@ as opposite.
        -- The difference is taken on 'Integer' to stay clear of 'Natural' underflow.
        neg (F 0) = zero
        neg (F x) = F (fromIntegral (toInteger (fieldCharac @p) Num.- toInteger x))
instance PrimeField p => Multiplicative (F p) where
        one = F 1
        -- Multiply the representatives, then reduce modulo 'fieldCharac'.
        -- Since 'fieldCharac' is prime, every non-zero element
        -- is invertible modulo 'fieldCharac'.
        F a * F b = F ((a * b) `mod` fieldCharac @p)
instance PrimeField p => Random.Random (F p) where
        -- Draw an 'Integer' within the requested bounds,
        -- clamped to the field's range @[0..'fieldCharac'-1]@.
        randomR (F lo, F hi) gen =
                let lower = toInteger lo `max` 0
                    upper = (toInteger (fieldCharac @p) - 1) `min` toInteger hi
                in first (F . fromIntegral) (Random.randomR (lower, upper) gen)
        -- Draw over the whole field.
        random gen =
                first (F . fromIntegral)
                 (Random.randomR (0, toInteger (fieldCharac @p) - 1) gen)

-- ** Class 'PrimeField'
-- | A type-level tag selecting a prime field:
-- each instance provides the characteristic of that field.
class PrimeField p where
        -- | The prime characteristic of the 'PrimeField' @p@,
        -- selected with a type application: @fieldCharac \@p@.
        --
        -- ElGamal's hardness of decryption relies on this prime
        -- being large enough to host the 'Multiplicative' 'SubGroup'.
        fieldCharac :: Natural

-- ** Class 'Additive'
-- | Types providing an addition ('+') and a 'zero' element.
class Additive a where
        zero :: a
        (+) :: a -> a -> a
        infixl 6 +
        -- | Add up the elements of a 'Foldable',
        -- starting from 'zero', with a strict left fold.
        sum :: Foldable f => f a -> a
        sum xs = foldl' (+) zero xs
instance Additive Natural where
        zero = 0
        x + y = x Num.+ y
instance Additive Integer where
        zero = 0
        x + y = x Num.+ y
instance Additive Int where
        zero = 0
        x + y = x Num.+ y

-- *** Class 'Negable'
-- | 'Additive' types whose elements have an opposite ('neg'),
-- which also provides a subtraction ('-').
class Additive a => Negable a where
        neg :: a -> a
        (-) :: a -> a -> a
        infixl 6 -
        -- By default, subtracting is adding the opposite.
        a - b = a + neg b
instance Negable Integer where
        neg x = Num.negate x
instance Negable Int where
        neg x = Num.negate x

-- ** Class 'Multiplicative'
-- | Types providing a multiplication ('*') and a 'one' element.
class Multiplicative a where
        one :: a
        (*) :: a -> a -> a
        infixl 7 *
instance Multiplicative Natural where
        one = 1
        x * y = x Num.* y
instance Multiplicative Integer where
        one = 1
        x * y = x Num.* y
instance Multiplicative Int where
        one = 1
        x * y = x Num.* y

-- ** Class 'Invertible'
-- | 'Multiplicative' types whose elements can be inverted with 'inv',
-- which also provides a division ('/').
class Multiplicative a => Invertible a where
        inv :: a -> a
        (/) :: a -> a -> a
        infixl 7 /
        -- By default, dividing is multiplying by the inverse.
        a / b = a * inv b

-- * Type 'G'
-- | An element of the 'Multiplicative' 'SubGroup' @q@
-- of the 'PrimeField' @'P' q@, represented by its field element.
newtype G q = G
 { unG :: F (P q)
 }
 deriving (Eq, Ord, Show, NFData)

instance PrimeField (P q) => FromNatural (G q) where
        -- Reduce into the field @'F' ('P' q)@.
        -- NOTE: this does not check that the result belongs to the 'SubGroup'.
        fromNatural n = G (fromNatural n)
instance ToNatural (G q) where
        nat (G (F n)) = n

instance (SubGroup q, Multiplicative (F (P q))) => Multiplicative (G q) where
        one = G one
        -- Multiplication happens in the underlying field.
        G a * G b = G (a * b)
instance (SubGroup q, Multiplicative (F (P q))) => Invertible (G q) where
        -- For g in a 'SubGroup' of order 'groupOrder', @g^groupOrder == one@,
        -- hence @g^(groupOrder-1)@ is the inverse of g.
        -- The exponent is computed as @'groupOrder' - 'one'@
        -- so that the exponent given to ('^') is not negative.
        inv g = g ^ E (groupOrder @q - one)

-- ** Class 'SubGroup'
-- | A 'Multiplicative' 'SubGroup' of a 'PrimeField',
-- used for signing (Schnorr) and encrypting (ElGamal).
class
 ( PrimeField (P q)
 , Multiplicative (F (P q))
 ) => SubGroup q where
        -- | The 'PrimeField' hosting the 'SubGroup':
        -- choosing @q@ fixes @p@ to be @'P' q@.
        type P q :: *
        -- | A generator of the 'SubGroup'.
        --
        -- NOTE: since @F p@ is a 'PrimeField', its 'Multiplicative' group
        -- is cyclic, and so is any of its subgroups.
        -- A cyclic group of order n has phi(n) generators
        -- (phi being Euler's totient function): phi('fieldCharac'-1)
        -- for the whole 'Multiplicative' group, phi('groupOrder') for this 'SubGroup'.
        groupGen :: G q
        -- | The order of the 'SubGroup'.
        --
        -- WARNING: 'groupOrder' MUST be a prime number dividing @('fieldCharac'-1)@
        -- to ensure that ElGamal is secure in terms of the DDH assumption.
        groupOrder :: F (P q)

        -- | The infinite list of 'inv'erse powers of 'groupGen':
        -- @['groupGen' '^' 'neg' i | i <- [0..]]@,
        -- each term being the previous one multiplied by @'inv' 'groupGen'@.
        --
        -- NOTE: 'groupGenInverses' is a method of the 'SubGroup' class
        -- in order to keep its computed terms in memory across calls.
        --
        -- Used by 'intervalDisjunctions'.
        groupGenInverses :: [G q]
        groupGenInverses = List.iterate (* invGen) one
                where invGen = inv groupGen

-- | The infinite list @['groupGen' '^' i | i <- [0..]]@,
-- each term being the previous one multiplied by 'groupGen'.
groupGenPowers :: SubGroup q => [G q]
groupGenPowers = List.iterate (* groupGen) one

-- | @('hash' bs gs)@ returns as a number in 'E'
-- the SHA256 of the given 'BS.ByteString' 'bs'
-- followed by the decimal representation of the given 'SubGroup' elements 'gs',
-- with a comma (",") intercalated between them.
--
-- The digest is read as a big-endian 'Natural'
-- and then reduced modulo 'groupOrder' by 'fromNatural'.
-- Provers and verifiers must agree on this exact decoding.
--
-- NOTE: to avoid any collision when the 'hash' function is used in different contexts,
-- a message 'gs' is actually prefixed by a 'bs' indicating the context.
--
-- Used by 'proveEncryption' and 'verifyEncryption',
-- where the 'bs' usually contains the 'statement' to be proven,
-- and the 'gs' contains the 'commitments'.
hash :: SubGroup q => BS.ByteString -> [G q] -> E q
hash bs gs =
        let s = bs <> BS.intercalate (fromString ",") (bytesNat <$> gs) in
        let h = ByteArray.convert (Crypto.hashWith Crypto.SHA256 s) in
        fromNatural (decodeBigEndian h)
        where
        -- Interpret the bytes as a big-endian number.
        -- Each byte carries 8 bits, so the accumulator is shifted by 8:
        -- a smaller shift would make consecutive bytes overlap,
        -- causing collisions and shrinking the range of the result.
        decodeBigEndian :: BS.ByteString -> Natural
        decodeBigEndian =
                BS.foldl' (\acc b -> acc `shiftL` 8 + fromIntegral b) (0::Natural)

-- * Type 'E'
-- | An exponent for the (necessarily cyclic) 'SubGroup' @q@ of a 'PrimeField'.
newtype E q = E
 { unE :: F (P q)
   -- ^ Expected to lie in @[0..'groupOrder'-1]@.
 }
 deriving (Eq, Ord, Show, NFData)

instance SubGroup q => FromNatural (E q) where
        -- Reduce modulo 'groupOrder' so that the exponent
        -- lies in @[0..'groupOrder'-1]@.
        -- NOTE: a 'Natural' remainder can never be negative,
        -- hence no correction step is needed after 'mod'.
        fromNatural i = E (F (i `mod` unF (groupOrder @q)))
instance ToNatural (E q) where
        nat = unF . unE

-- Arithmetic on exponents is carried out modulo 'groupOrder'.
instance (SubGroup q, Additive (F (P q))) => Additive (E q) where
        zero = E zero
        E (F a) + E (F b) = E (F (modOrder (a + b)))
                where modOrder n = n `mod` unF (groupOrder @q)
instance (SubGroup q, Negable (F (P q))) => Negable (E q) where
        -- 'zero' is its own opposite; any other x has @groupOrder - x@ as opposite.
        neg (E (F 0)) = zero
        neg (E (F x)) = E (F (fromIntegral (toInteger (unF (groupOrder @q)) - toInteger x)))
instance (SubGroup q, Multiplicative (F (P q))) => Multiplicative (E q) where
        one = E one
        E (F a) * E (F b) = E (F (modOrder (a * b)))
                where modOrder n = n `mod` unF (groupOrder @q)
instance SubGroup q => Random.Random (E q) where
        -- Draw an 'Integer' within the requested bounds,
        -- clamped to the exponents' range @[0..'groupOrder'-1]@.
        randomR (E (F lo), E (F hi)) gen =
                let lower = toInteger lo `max` 0
                    upper = (toInteger (unF (groupOrder @q)) - 1) `min` toInteger hi
                in first (E . F . fromIntegral) (Random.randomR (lower, upper) gen)
        -- Draw over all the exponents.
        random gen =
                first (E . F . fromIntegral)
                 (Random.randomR (0, toInteger (unF (groupOrder @q)) - 1) gen)
instance SubGroup q => Enum (E q) where
        toEnum = fromNatural . fromIntegral
        fromEnum = fromIntegral . nat
        -- Enumerate @[lo..hi]@ by repeatedly adding 'one'.
        -- NOTE: addition on 'E' wraps around modulo 'groupOrder',
        -- so the successor of the largest exponent is 'zero', which is @<= hi@:
        -- testing @i <= hi@ before each step would never terminate
        -- when @hi@ is that largest exponent. Stopping on @i == hi@ avoids this.
        enumFromTo lo hi
         | lo > hi = []
         | otherwise = go lo
                where
                go i
                 | i == hi = [i]
                 | otherwise = i : go (i + one)

infixr 8 ^
-- | @(b '^' e)@ returns the modular exponentiation of base 'b' by exponent 'e'.
--
-- Square-and-multiply over the bits of @e@, from least to most significant:
-- @sq@ runs through @b^(2^k)@ and @acc@ multiplies in those whose bit k is set.
--
-- NOTE(review): this branches on the bits of the exponent,
-- so it is not constant-time; confirm this is acceptable for secret exponents.
(^) :: SubGroup q => G q -> E q -> G q
(^) base (E (F ex)) = loop one base ex
        where
        loop acc sq e
         | e == zero = acc
         | testBit e 0 = loop (acc * sq) (sq * sq) (e `shiftR` 1)
         | otherwise = loop acc (sq * sq) (e `shiftR` 1)

-- | @('randomR' i)@ returns a random integer in @[0..i-1]@,
-- drawn from (and advancing) the generator held in the state.
randomR ::
 Monad m =>
 Random.RandomGen r =>
 Random.Random i =>
 Negable i =>
 Multiplicative i =>
 i -> S.StateT r m i
randomR i = S.state (Random.randomR (zero, i - one))

-- | @('random')@ returns a random value of type @i@
-- in the range determined by its 'Random.Random' instance,
-- drawn from (and advancing) the generator held in the state.
--
-- Unlike 'randomR', no arithmetic is performed on @i@,
-- hence only 'Random.Random' is required
-- (which also makes 'random' usable at 'Natural').
random ::
 Monad m =>
 Random.RandomGen r =>
 Random.Random i =>
 S.StateT r m i
random = S.StateT $ return . Random.random

-- Orphan instance: neither 'Random.Random' nor 'Natural' is defined in this module.
instance Random.Random Natural where
        -- Draw an 'Integer' within the same bounds and convert it back.
        randomR (mini,maxi) =
                first (fromIntegral::Integer -> Natural) .
                Random.randomR (fromIntegral mini, fromIntegral maxi)
        -- 'Random.random' at 'Integer' draws from the range of 'Int',
        -- negative numbers included, and converting a negative 'Integer'
        -- to 'Natural' throws an underflow exception.
        -- Hence draw only from the non-negative part of that range.
        random =
                first (fromIntegral::Integer -> Natural) .
                Random.randomR (0, toInteger (Num.maxBound :: Int))

-- * Groups

-- * Class 'Params'
-- | 'SubGroup's identified by a name.
class SubGroup q => Params q where
        -- | The name of the parameters @q@.
        paramsName :: String
instance Params BeleniosParams where
        paramsName = "BeleniosParams"
instance Params WeakParams where
        paramsName = "WeakParams"

-- ** Type 'WeakParams'
-- | Weak parameters for debugging purposes only:
-- a 'fieldCharac' of 263 (= 2*131+1) with a 'groupOrder' of 131.
data WeakParams
instance SubGroup WeakParams where
        type P WeakParams = WeakParams
        groupOrder = F 131
        groupGen = G (F 2)
instance PrimeField WeakParams where
        fieldCharac = 263

-- ** Type 'BeleniosParams'
-- | Parameters used in Belenios.
-- A 2048-bit 'fieldCharac' of a 'PrimeField',
-- with a 256-bit 'groupOrder' for a 'Multiplicative' 'SubGroup'
-- generated by 'groupGen'.
--
-- NOTE(review): these constants are presumably copied from Belenios;
-- confirm against the upstream reference before changing any digit.
data BeleniosParams
instance PrimeField BeleniosParams where
        -- The 2048-bit prime characteristic of the field.
        fieldCharac = 20694785691422546401013643657505008064922989295751104097100884787057374219242717401922237254497684338129066633138078958404960054389636289796393038773905722803605973749427671376777618898589872735865049081167099310535867780980030790491654063777173764198678527273474476341835600035698305193144284561701911000786737307333564123971732897913240474578834468260652327974647951137672658693582180046317922073668860052627186363386088796882120769432366149491002923444346373222145884100586421050242120365433561201320481118852408731077014151666200162313177169372189248078507711827842317498073276598828825169183103125680162072880719
instance SubGroup BeleniosParams where
        -- The field hosting this 'SubGroup' is tagged by 'BeleniosParams' itself.
        type P BeleniosParams = BeleniosParams
        -- The generator of the 'SubGroup'.
        groupGen = G (F 2402352677501852209227687703532399932712287657378364916510075318787663274146353219320285676155269678799694668298749389095083896573425601900601068477164491735474137283104610458681314511781646755400527402889846139864532661215055797097162016168270312886432456663834863635782106154918419982534315189740658186868651151358576410138882215396016043228843603930989333662772848406593138406010231675095763777982665103606822406635076697764025346253773085133173495194248967754052573659049492477631475991575198775177711481490920456600205478127054728238140972518639858334115700568353695553423781475582491896050296680037745308460627)
        -- The 256-bit order of the 'SubGroup'; 'SubGroup' requires it
        -- to be a prime dividing @'fieldCharac'-1@.
        groupOrder = F 78571733251071885079927659812671450121821421258408794611510081919805623223441

-- * Conversions

-- ** Class 'ToNatural'
-- | Types whose values can be read back as a 'Natural'.
class ToNatural a where
        nat :: a -> Natural

-- ** Class 'FromNatural'
-- | Types whose values can be built from a 'Natural'.
class FromNatural a where
        fromNatural :: Natural -> a

-- | @('bytesNat' x)@ serializes 'x' as the decimal text
-- (given by 'show') of its 'Natural' value.
bytesNat :: ToNatural n => n -> BS.ByteString
bytesNat n = fromString (show (nat n))