{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
module Protocol.Election where
import Control.Monad (Monad(..), join, mapM, zipWithM)
import Control.Monad.Morph (MFunctor(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Bool
import Data.Either (either)
import Data.Eq (Eq(..))
import Data.Foldable (Foldable, foldMap, and)
import Data.Function (($), id, const)
import Data.Functor (Functor, (<$>))
import Data.Functor.Identity (Identity(..))
import Data.Maybe (Maybe(..), fromMaybe)
import Data.Ord (Ord(..))
import Data.Semigroup (Semigroup(..))
import Data.Text (Text)
import Data.Traversable (Traversable(..))
import Data.Tuple (fst, snd, uncurry)
import GHC.Natural (minusNaturalMaybe)
import Numeric.Natural (Natural)
import Prelude (fromIntegral)
import Text.Show (Show(..))
import qualified Control.Monad.Trans.Except as Exn
import qualified Control.Monad.Trans.State.Strict as S
import qualified Data.ByteString as BS
import qualified Data.List as List
import Protocol.Arithmetic
import Protocol.Credential
-- | ElGamal-style ciphertext over the group @G q@.
--
-- 'encrypt' builds it as @(groupGen^r, pubKey^r * groupGen^clear)@ for a
-- random nonce @r@ and a clear exponent @clear@.
data Encryption q = Encryption
  { encryption_nonce :: G q
    -- ^ @groupGen^r@, where @r@ is the encryption nonce.
  , encryption_vault :: G q
    -- ^ @pubKey^r * groupGen^clear@: the clear exponent masked by the nonce.
  } deriving (Eq,Show)
-- | Ciphertexts add by multiplying their components: nonces together and
-- vaults together. Mathematically the result encrypts the sum of the clear
-- exponents under the sum of the nonces, which is what lets a sum of
-- encrypted opinions be proven in 'encryptAnswer'.
instance SubGroup q => Additive (Encryption q) where
  zero = Encryption { encryption_nonce = one, encryption_vault = one }
  -- Lazy patterns: like field accessors, they do not force the operands
  -- until one of the components is demanded.
  ~(Encryption nonceX vaultX) + ~(Encryption nonceY vaultY) =
    Encryption (nonceX * nonceY) (vaultX * vaultY)
-- | The random exponent drawn by 'encrypt'. It is returned to the caller
-- because 'proveEncryption' needs it to build a genuine proof.
type EncryptionNonce = E

-- | Encrypt the exponent @clear@ under @pubKey@.
--
-- Draws a fresh nonce @n@ from the generator held in the state and returns
-- it alongside @Encryption (groupGen^n) (pubKey^n * groupGen^clear)@.
encrypt ::
  Monad m => RandomGen r => SubGroup q =>
  PublicKey q -> E q ->
  S.StateT r m (EncryptionNonce q, Encryption q)
encrypt pubKey clear =
  random >>= \nonce ->
    let ciphertext = Encryption
          { encryption_nonce = groupGen ^ nonce
          , encryption_vault = pubKey ^ nonce * groupGen ^ clear
          }
    in return (nonce, ciphertext)
-- | Non-interactive proof of knowledge of an exponent, produced by 'prove'.
-- Its commitments are recomputed with 'commit' when checking it.
data Proof q = Proof
  { proof_challenge :: Challenge q
    -- ^ Challenge from the 'Oracle', or drawn at random for the simulated
    -- proofs of 'proveEncryption'.
  , proof_response :: E q
    -- ^ For a proof built by 'prove': @nonce - sec * proof_challenge@.
  } deriving (Eq,Show)

-- | Context bytes mixed into the hashed statements of proofs
-- ('encryptionStatement', 'signatureCommitments'). 'encryptBallot' sets it
-- to the voter's public-key bytes when signing, and to @""@ otherwise.
newtype ZKP = ZKP BS.ByteString

-- | Challenge of a 'Proof': an exponent.
type Challenge = E

-- | Derives a proof's challenge from its commitments. This is Fiat–Shamir
-- style: the callers in this module build it from 'hash'.
type Oracle list q = list (Commitment q) -> Challenge q
-- | Prove knowledge of the exponent @sec@ shared by every base in
-- @commitBases@.
--
-- A random nonce @w@ is drawn. The challenge is @oracle@ applied to
-- @base^w@ for each base, and the response is @w - sec * challenge@. As a
-- result, 'commit' recovers @base^w@ from @base@ and @base^sec@.
prove ::
  Monad m => RandomGen r => SubGroup q => Functor list =>
  E q -> list (Commitment q) -> Oracle list q -> S.StateT r m (Proof q)
prove sec commitBases oracle =
  random >>= \nonce ->
    let challenge = oracle ((^ nonce) <$> commitBases)
    in return Proof
         { proof_challenge = challenge
         , proof_response = nonce - sec * challenge
         }
-- | A group element committed to by a proof, e.g. @base^nonce@ in 'prove'.
type Commitment = G

-- | Recompute a proof's commitment from a @base@ and @basePowSec@ (the base
-- raised to the proven secret): @base^response * basePowSec^challenge@.
--
-- For a proof built by 'prove' this equals @base^nonce@, because the
-- response is @nonce - sec * challenge@.
commit :: SubGroup q => Proof q -> G q -> G q -> Commitment q
commit Proof{proof_response = response, proof_challenge = challenge} base basePowSec =
  base ^ response * basePowSec ^ challenge
{-# INLINE commit #-}
-- | A factor multiplied into a vault by 'encryptionCommitments'
-- (@encryption_vault * disj@); only the one matching the clear value makes
-- the genuine proof check out.
-- NOTE(review): presumably @groupGen^(-v)@ for an allowed value @v@ —
-- confirm against 'groupGenInverses'.
type Disjunction = G

-- | Disjunctions for a yes/no opinion: the first two 'groupGenInverses'.
booleanDisjunctions :: SubGroup q => [Disjunction q]
booleanDisjunctions = fst $ List.splitAt 2 groupGenInverses
-- | Disjunctions for an integer opinion in @[mini, maxi]@: the
-- 'groupGenInverses' at positions @nat mini@ through @nat maxi@ inclusive.
-- The result is empty when @maxi < mini@.
intervalDisjunctions :: SubGroup q => Opinion q -> Opinion q -> [Disjunction q]
intervalDisjunctions mini maxi = List.genericTake count fromMini
  where
    fromMini = List.genericDrop (nat mini) groupGenInverses
    -- Number of values in the interval; 'minusNaturalMaybe' yields Nothing
    -- when the interval is empty, which is mapped to 0.
    count = fromMaybe 0 (minusNaturalMaybe (nat maxi + 1) (nat mini))
-- | An opinion value, as an exponent ('encryptAnswer' maps 'False' to
-- 'zero' and 'True' to 'one').
type Opinion = E

-- | One 'Proof' per disjunction. As built by 'proveEncryption', exactly one
-- of them is genuine and the others are simulated.
newtype DisjProof q = DisjProof [Proof q]
  deriving (Eq,Show)
-- | Prove that @enc@ encrypts the value of one disjunction among
-- @prevDisjs <> [real] <> nextDisjs@, without revealing which.
--
-- The caller passes the disjunctions before and after the real one, and
-- omits the real one itself. Each listed disjunction gets a simulated proof:
-- a random challenge and response, with commitments recomputed by
-- 'encryptionCommitments'. The real proof's challenge is chosen so that all
-- challenges sum to the hash of the statement and of every commitment in
-- order, which is what 'verifyEncryption' checks.
proveEncryption ::
  forall m r q.
  Monad m => RandomGen r => SubGroup q =>
  PublicKey q -> ZKP ->
  ([Disjunction q],[Disjunction q]) ->
  (EncryptionNonce q, Encryption q) ->
  S.StateT r m (DisjProof q)
proveEncryption elecPubKey voterZKP (prevDisjs,nextDisjs) (encNonce,enc) = do
  simulatedBefore <- traverse simulate prevDisjs
  simulatedAfter <- traverse simulate nextDisjs
  let proofsBefore = fst <$> simulatedBefore
      proofsAfter = fst <$> simulatedAfter
      simulatedChallenges =
        sum (proof_challenge <$> proofsBefore) +
        sum (proof_challenge <$> proofsAfter)
      statement = encryptionStatement voterZKP enc
      -- Oracle of the real proof: its challenge is whatever makes all the
      -- challenges add up to the hash of every commitment, in order.
      realChallenge realCommitments =
        let allCommitments =
              foldMap snd simulatedBefore <> realCommitments <> foldMap snd simulatedAfter
        in hash statement allCommitments - simulatedChallenges
  realProof <- prove encNonce [groupGen, elecPubKey] realChallenge
  return $ DisjProof $ proofsBefore <> (realProof : proofsAfter)
  where
    -- Simulated proof for a disjunction the ciphertext does not match:
    -- draw the challenge, then the response, then derive the commitments.
    simulate :: Disjunction q -> S.StateT r m (Proof q, [Commitment q])
    simulate disj = do
      challenge <- random
      response <- random
      let simulated = Proof { proof_challenge = challenge, proof_response = response }
      return (simulated, encryptionCommitments elecPubKey enc (disj, simulated))
-- | Check a 'DisjProof' that @enc@ encrypts the value of one of @disjs@.
--
-- Throws 'ErrorValidateEncryption_InvalidProofLength' (number of proofs,
-- then number of disjunctions) unless there is exactly one proof per
-- disjunction. Otherwise it returns whether the challenges sum to the hash
-- of the statement and of all recomputed commitments, in disjunction order.
verifyEncryption ::
  Monad m =>
  SubGroup q =>
  PublicKey q -> ZKP ->
  [Disjunction q] ->
  (Encryption q, DisjProof q) ->
  Exn.ExceptT ErrorValidateEncryption m Bool
verifyEncryption elecPubKey voterZKP disjs (enc, DisjProof proofs) =
  if proofCount == disjCount
    then return (challengeTotal == hashed)
    else Exn.throwE $
      ErrorValidateEncryption_InvalidProofLength
        (fromIntegral proofCount)
        (fromIntegral disjCount)
  where
    proofCount = List.length proofs
    disjCount = List.length disjs
    challengeTotal = sum (proof_challenge <$> proofs)
    hashed =
      hash (encryptionStatement voterZKP enc) $
        foldMap (encryptionCommitments elecPubKey enc) (List.zip disjs proofs)
-- | Bytes hashed together with the commitments of an encryption proof:
-- @"prove|" <> zkp <> "|" <> nonce <> "," <> vault <> "|"@.
encryptionStatement :: SubGroup q => ZKP -> Encryption q -> BS.ByteString
encryptionStatement (ZKP voterZKP) enc = BS.concat
  [ "prove|", voterZKP, "|"
  , bytesNat (encryption_nonce enc), ","
  , bytesNat (encryption_vault enc), "|"
  ]
-- | The two commitments covered by the proof of one disjunction, recomputed
-- with 'commit'. One uses @groupGen@ against the ciphertext nonce; the
-- other uses @elecPubKey@ against the vault multiplied by @disj@.
encryptionCommitments ::
  SubGroup q =>
  PublicKey q -> Encryption q ->
  (Disjunction q, Proof q) -> [G q]
encryptionCommitments elecPubKey (Encryption nonce vault) (disj, proof) =
  [nonceCommitment, vaultCommitment]
  where
    nonceCommitment = commit proof groupGen nonce
    vaultCommitment = commit proof elecPubKey (vault * disj)
-- | Failure of 'verifyEncryption'.
data ErrorValidateEncryption
  = ErrorValidateEncryption_InvalidProofLength Natural Natural
    -- ^ Number of proofs, then number of disjunctions.
  deriving (Eq,Show)

-- | A question whose choices are each answered yes/no. 'encryptAnswer'
-- requires the number of yes answers to lie within
-- @[question_mini, question_maxi]@.
data Question q = Question
  { question_text :: Text
  , question_choices :: [Text]
    -- ^ One opinion is expected per choice.
  , question_mini :: Opinion q
    -- ^ Lowest accepted sum of opinions.
  , question_maxi :: Opinion q
    -- ^ Highest accepted sum of opinions.
  } deriving (Eq, Show)

-- | An encrypted answer to a 'Question', as built by 'encryptAnswer'.
data Answer q = Answer
  { answer_opinions :: [(Encryption q, DisjProof q)]
    -- ^ One encrypted opinion per choice, each with its proof against
    -- 'booleanDisjunctions'.
  , answer_sumProof :: DisjProof q
    -- ^ Proof for the sum of all opinions, against the question's
    -- 'intervalDisjunctions'.
  } deriving (Eq,Show)
-- | Encrypt a voter's yes/no opinions for a 'Question'.
--
-- The checks run in this order:
--
-- * First, it fails with 'ErrorAnswer_WrongSumOfOpinions' (sum, mini,
--   maxi) when the number of 'True's is outside
--   @[question_mini, question_maxi]@.
-- * Then, it fails with 'ErrorAnswer_WrongNumberOfOpinions' (given,
--   expected) when there is not exactly one opinion per choice.
--
-- Otherwise every opinion is encrypted and proven against
-- 'booleanDisjunctions'. The sum of the encryptions is then proven against
-- the question's 'intervalDisjunctions'.
encryptAnswer ::
  Monad m => RandomGen r => SubGroup q =>
  PublicKey q -> ZKP ->
  Question q -> [Bool] ->
  S.StateT r (Exn.ExceptT ErrorAnswer m) (Answer q)
encryptAnswer elecPubKey zkp Question{..} opinionsBools
  | not withinBounds =
      abort $
        ErrorAnswer_WrongSumOfOpinions
          (nat opinionsSum)
          (nat question_mini)
          (nat question_maxi)
  | opinionCount /= choiceCount =
      abort $
        ErrorAnswer_WrongNumberOfOpinions
          (fromIntegral opinionCount)
          (fromIntegral choiceCount)
  | otherwise = do
      encryptions <- traverse (encrypt elecPubKey) opinions
      individualProofs <- zipWithM proveOpinion opinionsBools encryptions
      -- The disjunction matching the actual sum sits at index
      -- (opinionsSum - question_mini). It is left out of what is handed to
      -- 'proveEncryption', which proves that one genuinely.
      let (lowerDisjs, fromSumDisjs) =
            List.genericSplitAt (nat (opinionsSum - question_mini)) $
              intervalDisjunctions question_mini question_maxi
      sumProof <- proveEncryption elecPubKey zkp
        (lowerDisjs, List.tail fromSumDisjs)
        (sum (fst <$> encryptions), sum (snd <$> encryptions))
      return Answer
        { answer_opinions = List.zip (snd <$> encryptions) individualProofs
        , answer_sumProof = sumProof
        }
  where
    abort err = lift $ Exn.throwE err
    opinions = bool zero one <$> opinionsBools
    opinionsSum = sum opinions
    opinionCount = List.length opinions
    choiceCount = List.length question_choices
    withinBounds = question_mini <= opinionsSum && opinionsSum <= question_maxi
    -- The disjunction of the actual opinion is omitted: index 1 for 'True',
    -- index 0 for 'False'.
    proveOpinion isYes =
      proveEncryption elecPubKey zkp $
        if isYes
          then ([booleanDisjunctions List.!! 0], [])
          else ([], [booleanDisjunctions List.!! 1])
-- | Check an 'Answer' against its 'Question'.
--
-- Returns 'False' in any of these cases:
--
-- * the number of opinions differs from the number of choices;
-- * any proof has the wrong number of disjuncts;
-- * any proof does not check out.
--
-- Every opinion is checked against 'booleanDisjunctions'. The sum of the
-- encrypted opinions is checked against the question's
-- 'intervalDisjunctions'.
verifyAnswer ::
  SubGroup q =>
  PublicKey q -> ZKP ->
  Question q -> Answer q -> Bool
verifyAnswer elecPubKey zkp Question{..} Answer{..} =
  List.length question_choices == List.length answer_opinions &&
  either (const False) id (Exn.runExcept checks)
  where
    checks = do
      opinionsOk <-
        traverse (verifyEncryption elecPubKey zkp booleanDisjunctions) answer_opinions
      sumOk <-
        verifyEncryption elecPubKey zkp
          (intervalDisjunctions question_mini question_maxi)
          (sum (fst <$> answer_opinions), answer_sumProof)
      return (and opinionsOk && sumOk)
-- | Failures of 'encryptAnswer'.
data ErrorAnswer
  = ErrorAnswer_WrongNumberOfOpinions Natural Natural
    -- ^ Number of opinions given, then number of choices in the question.
  | ErrorAnswer_WrongSumOfOpinions Natural Natural Natural
    -- ^ Sum of the opinions, then the question's minimum and maximum.
  deriving (Eq,Show)

-- | Public description of an election.
data Election q = Election
  { election_name :: Text
  , election_description :: Text
  , election_publicKey :: PublicKey q
    -- ^ Key under which 'encryptBallot' encrypts the answers.
  , election_questions :: [Question q]
  , election_uuid :: UUID
    -- ^ Copied into each 'Ballot'; compared by 'verifyBallot'.
  , election_hash :: Hash
    -- ^ Copied into each 'Ballot'; compared by 'verifyBallot'.
  } deriving (Eq,Show)

-- | Textual hash identifying an election.
newtype Hash = Hash Text
  deriving (Eq,Ord,Show)

-- | A ballot: one 'Answer' per election question, optionally signed.
-- NOTE(review): unlike its neighbours this derives no instances ('Signature'
-- has none either) — confirm that is intended.
data Ballot q = Ballot
  { ballot_answers :: [Answer q]
  , ballot_signature :: Maybe (Signature q)
    -- ^ Present when 'encryptBallot' was given a secret key.
  , ballot_election_uuid :: UUID
  , ballot_election_hash :: Hash
  }
-- | Encrypt one list of opinions per election question into a 'Ballot'.
--
-- Throws 'ErrorBallot_WrongNumberOfAnswers' (given, expected) when the
-- number of opinion lists differs from the number of questions.
-- 'encryptAnswer' failures are wrapped in 'ErrorBallot_Answer'.
--
-- With a secret key, the voter's public-key bytes become the 'ZKP' of every
-- proof, and the ballot carries a 'Signature' over the encrypted answers.
-- Without one, the 'ZKP' is empty and the ballot is unsigned.
encryptBallot ::
  Monad m => RandomGen r => SubGroup q =>
  Election q -> Maybe (SecretKey q) -> [[Bool]] ->
  S.StateT r (Exn.ExceptT ErrorBallot m) (Ballot q)
encryptBallot Election{..} secKeyMay opinionsByQuest
  | questionCount /= opinionCount =
      lift $ Exn.throwE $
        ErrorBallot_WrongNumberOfAnswers
          (fromIntegral opinionCount)
          (fromIntegral questionCount)
  | otherwise = do
      answers <-
        hoist (Exn.withExceptT ErrorBallot_Answer) $
          zipWithM (encryptAnswer election_publicKey voterZKP)
            election_questions opinionsByQuest
      signature <- traverse (signAnswers answers) voterKeys
      return Ballot
        { ballot_answers = answers
        , ballot_signature = signature
        , ballot_election_uuid = election_uuid
        , ballot_election_hash = election_hash
        }
  where
    questionCount = List.length election_questions
    opinionCount = List.length opinionsByQuest
    -- Secret key paired with its public key, when the voter signs.
    voterKeys = (\secKey -> (secKey, publicKey secKey)) <$> secKeyMay
    voterZKP = case voterKeys of
      Nothing -> ZKP ""
      Just (_, pubKey) -> ZKP (bytesNat pubKey)
    -- Proof of knowledge of the secret key, bound to the encrypted answers.
    signAnswers answers (secKey, pubKey) = do
      proof <-
        prove secKey (Identity groupGen) $ \(Identity commitment) ->
          hash
            (signatureCommitments voterZKP commitment)
            (signatureStatement answers)
      return Signature
        { signature_publicKey = pubKey
        , signature_proof = proof
        }
-- | Check a 'Ballot' against its 'Election'. All of the following must hold:
--
-- * the ballot names the election's UUID and hash;
-- * it holds one answer per question;
-- * its signature, if any, checks out;
-- * every answer passes 'verifyAnswer'.
--
-- Answers of a signed ballot are checked with the signer's public-key bytes
-- as 'ZKP'; those of an unsigned ballot with an empty 'ZKP'.
verifyBallot :: SubGroup q => Election q -> Ballot q -> Bool
verifyBallot Election{..} Ballot{..} =
  ballot_election_uuid == election_uuid &&
  ballot_election_hash == election_hash &&
  List.length election_questions == List.length ballot_answers &&
  and (signatureValid : answersValid)
  where
    (signatureValid, answerZKP) =
      case ballot_signature of
        Nothing -> (True, ZKP "")
        Just Signature{signature_publicKey = signer, signature_proof = proof} ->
          let signerZKP = ZKP (bytesNat signer)
              expectedChallenge =
                hash
                  (signatureCommitments signerZKP (commit proof groupGen signer))
                  (signatureStatement ballot_answers)
          in (proof_challenge proof == expectedChallenge, signerZKP)
    answersValid =
      List.zipWith (verifyAnswer election_publicKey answerZKP)
        election_questions ballot_answers
-- | Signature of a ballot, built by 'encryptBallot' and checked by
-- 'verifyBallot'. It is a 'Proof' over the base @groupGen@, checked via
-- 'commit' against 'signature_publicKey'.
data Signature q = Signature
  { signature_publicKey :: PublicKey q
    -- ^ Signer's public key; its bytes also serve as the ballot's 'ZKP'.
  , signature_proof :: Proof q
  }
-- | Group elements hashed into a ballot signature: the nonce and then the
-- vault of every encrypted opinion, answer by answer, in order.
signatureStatement :: Foldable f => SubGroup q => f (Answer q) -> [G q]
signatureStatement answers = foldMap answerElements answers
  where
    answerElements Answer{answer_opinions = opinions} =
      foldMap opinionElements opinions
    opinionElements (Encryption nonce vault, _) = [nonce, vault]
-- | Bytes hashed together with 'signatureStatement' for a ballot signature:
-- @"sig|" <> zkp <> "|" <> commitment <> "|"@.
signatureCommitments :: SubGroup q => ZKP -> Commitment q -> BS.ByteString
signatureCommitments (ZKP voterZKP) commitment =
  BS.concat ["sig|", voterZKP, "|", bytesNat commitment, "|"]
-- | Failures of 'encryptBallot'.
data ErrorBallot
  = ErrorBallot_WrongNumberOfAnswers Natural Natural
    -- ^ Number of opinion lists given, then number of questions.
  | ErrorBallot_Answer ErrorAnswer
    -- ^ 'encryptAnswer' failed for one of the questions.
  deriving (Eq,Show)

-- | Decryption factors and proofs computed from one secret key by
-- 'computeDecryptionShare', nested like the input encryptions.
data DecryptionShare q = DecryptionShare
  { decryptionShare_factors :: [[DecryptionFactor q]]
    -- ^ @encryption_nonce^secKey@ for each encryption.
  , decryptionShare_proofs :: [[Proof q]]
    -- ^ The proof accompanying each factor, at the same position.
  } deriving (Eq,Show)
-- | For every encryption of a nested tally, compute its decryption factor
-- and proof under @secKey@. The result keeps the nesting of @encs@.
computeDecryptionShare ::
  Monad m => SubGroup q => RandomGen r =>
  SecretKey q -> [[Encryption q]] -> S.StateT r m (DecryptionShare q)
computeDecryptionShare secKey encs = do
  pairs <- traverse (traverse (decryptionFactor secKey)) encs
  return DecryptionShare
    { decryptionShare_factors = (fst <$>) <$> pairs
    , decryptionShare_proofs = (snd <$>) <$> pairs
    }
-- | Decryption factor @encryption_nonce^secKey@ of one encryption, together
-- with a proof built by 'prove' over the bases @groupGen@ and
-- @encryption_nonce@. 'checkDecryptionShare' checks that proof against the
-- public key and the factor.
decryptionFactor ::
  Monad m => SubGroup q => RandomGen r =>
  SecretKey q -> Encryption q -> S.StateT r m (DecryptionFactor q, Proof q)
decryptionFactor secKey Encryption{encryption_nonce = nonce} =
  (\proof -> (nonce ^ secKey, proof)) <$> prove secKey [groupGen, nonce] oracle
  where
    oracle = hash (decryptionStatement (publicKey secKey))
-- | Bytes hashed into decryption-factor proofs:
-- @"decrypt|" <> pubKey <> "|"@.
decryptionStatement :: SubGroup q => PublicKey q -> BS.ByteString
decryptionStatement pubKey = BS.concat ["decrypt|", bytesNat pubKey, "|"]
-- | An encryption nonce raised to a secret key (see 'decryptionFactor').
type DecryptionFactor = G

-- | Failure of 'checkDecryptionShare', raised by its length guard.
data ErrorDecryptionShare
  = ErrorDecryptionShare_Invalid
  deriving (Eq,Show)
-- | Verify a 'DecryptionShare' for the nested tally @encTally@ under
-- @pubKey@.
--
-- Throws 'ErrorDecryptionShare_Invalid' unless the share's factors and its
-- proofs both have exactly the shape of @encTally@: the same outer length
-- and the same length for every inner list.
--
-- Otherwise it returns whether every proof's challenge equals the hash of
-- 'decryptionStatement' and of the commitments recomputed by 'commit'. The
-- commitments use @(groupGen, pubKey)@ and @(encryption_nonce, factor)@.
--
-- No 'RandomGen' constraint: verification draws no randomness, and a
-- constraint on a type variable absent from the type fails GHC's ambiguity
-- check.
checkDecryptionShare ::
  Monad m => SubGroup q =>
  [[Encryption q]] -> PublicKey q -> DecryptionShare q ->
  Exn.ExceptT ErrorDecryptionShare m Bool
checkDecryptionShare encTally pubKey DecryptionShare{..}
  -- Reject a share whose shape differs from the tally. Without this check,
  -- 'List.zipWith3' would silently drop unmatched entries, and a share
  -- missing factors or proofs would verify vacuously.
  | not (sameShape decryptionShare_factors && sameShape decryptionShare_proofs) =
      Exn.throwE ErrorDecryptionShare_Invalid
  | otherwise =
      return $ and $ join $
        List.zipWith3 (List.zipWith3 validFactor)
          decryptionShare_factors decryptionShare_proofs encTally
  where
    zkp = decryptionStatement pubKey
    -- Same outer length as the tally, and same length for each inner list.
    sameShape :: [[a]] -> Bool
    sameShape xss =
      List.length xss == List.length encTally &&
      and (List.zipWith (\xs encs -> List.length xs == List.length encs) xss encTally)
    validFactor encFactor proof Encryption{..} =
      hash zkp
        [ commit proof groupGen pubKey
        , commit proof encryption_nonce encFactor
        ] == proof_challenge proof