{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
module Voting.Protocol.Tally where

import Control.DeepSeq (NFData)
import Control.Monad (Monad(..), mapM, unless)
import Control.Monad.Trans.Except (Except, ExceptT, throwE)
import Data.Eq (Eq(..))
import Data.Function (($))
import Data.Functor ((<$>))
import Data.Maybe (maybe)
import Data.Semigroup (Semigroup(..))
import Data.Text (Text)
import Data.Tuple (fst, uncurry)
import GHC.Generics (Generic)
import Numeric.Natural (Natural)
import Prelude (fromIntegral)
import Text.Show (Show(..))
import qualified Control.Monad.Trans.State.Strict as S
import qualified Data.ByteString as BS
import qualified Data.List as List
import qualified Data.Map.Strict as Map

import Voting.Protocol.Utils
import Voting.Protocol.Arithmetic
import Voting.Protocol.Credential
import Voting.Protocol.Election

-- * Type 'Tally'
-- | Result of 'proveTally': the summed 'Encryption's,
-- the trustees' 'DecryptionShare's used to decrypt them,
-- and the decrypted counts; 'verifyTally' re-checks these against each other.
data Tally q = Tally
 { tally_countMax :: Natural
   -- ^ The maximal number of supportive 'Opinion's that a choice can get,
   -- which is here the same as the number of 'Ballot's.
   --
   -- Used in 'proveTally' to decrypt the actual
   -- count of votes obtained by a choice,
   -- by precomputing all powers of 'groupGen' up to it.
 , tally_encByChoiceByQuest :: EncryptedTally q
   -- ^ 'Encryption' by choice by 'Question' (see 'EncryptedTally').
 , tally_decShareByTrustee :: [DecryptionShare q]
   -- ^ 'DecryptionShare' by trustee.
 , tally_countByChoiceByQuest :: [[Natural]]
   -- ^ The decrypted count of supportive 'Opinion's, by choice by 'Question'.
 } deriving (Eq,Show,Generic,NFData)

-- ** Type 'EncryptedTally'
-- | 'Encryption' by choice by 'Question'.
--
-- 'encryptedTally' builds one by adding up, position by position,
-- the encrypted opinions of every 'Ballot'.
type EncryptedTally q = [[Encryption q]]

-- | @('encryptedTally' ballots)@
-- adds up, choice by choice and 'Question' by 'Question',
-- the encrypted opinions (first component of each 'answer_opinions' entry)
-- of the given @ballots@, and pairs that sum with the number of @ballots@.
--
-- The sum is seeded with an unbounded grid of 'zero's, which 'List.zipWith'
-- truncates to the shortest list it meets. Consequently, for an empty
-- @ballots@ the first component is that infinite grid itself.
encryptedTally :: SubGroup q => [Ballot q] -> (EncryptedTally q, Natural)
encryptedTally ballots = (encByChoiceByQuest, ballotCount)
    where
        encByChoiceByQuest =
                List.foldr addBallot (List.repeat (List.repeat zero)) ballots
        -- Add one 'Ballot' to the running sum, 'Answer' by 'Answer'.
        addBallot Ballot{..} accByChoiceByQuest =
                List.zipWith addAnswer ballot_answers accByChoiceByQuest
        -- Add the encrypted opinions of one 'Answer', choice by choice.
        addAnswer Answer{..} accByChoice =
                List.zipWith (+) (fst <$> answer_opinions) accByChoice
        ballotCount = fromIntegral $ List.length ballots

-- ** Type 'DecryptionShareCombinator'
-- | A way of turning the trustees' 'DecryptionShare's into
-- one combined 'DecryptionFactor' by choice by 'Question',
-- possibly failing with an 'ErrorDecryptionShare'.
--
-- 'proveTally' and 'verifyTally' divide each 'encryption_vault'
-- by the corresponding combined factor.
type DecryptionShareCombinator q =
        [DecryptionShare q] -> Except ErrorDecryptionShare [[DecryptionFactor q]]

-- | @('proveTally' (encTally, countMax) decShares decShareCombinator)@
-- combines the trustees' @decShares@ with @decShareCombinator@,
-- divides each 'encryption_vault' of @encTally@ by its combined
-- 'DecryptionFactor', and recovers every count by looking the quotient up
-- in a table mapping 'groupGenPowers' to their index, from @0@ to @countMax@.
--
-- Throws 'ErrorDecryptionShare_Invalid' (labelled @proveTally@) when
-- @encTally@ and the combined factors do not line up, and
-- 'ErrorDecryptionShare_InvalidMaxCount' when a quotient is not in that table.
proveTally ::
 SubGroup q =>
 (EncryptedTally q, Natural) -> [DecryptionShare q] ->
 DecryptionShareCombinator q ->
 Except ErrorDecryptionShare (Tally q)
proveTally
 (tally_encByChoiceByQuest, tally_countMax)
 tally_decShareByTrustee
 decShareCombinator = do
        decFactorByChoiceByQuest <- decShareCombinator tally_decShareByTrustee
        powByChoiceByQuest <-
                isoZipWithM invalid stripQuestion
                 tally_encByChoiceByQuest
                 decFactorByChoiceByQuest
        tally_countByChoiceByQuest <- mapM (mapM discreteLog) powByChoiceByQuest
        return Tally{..}
    where
        invalid = throwE $ ErrorDecryptionShare_Invalid "proveTally"
        -- Divide out the combined factor of every choice of one 'Question'.
        stripQuestion encByChoice decFactorByChoice =
                maybe invalid return $
                isoZipWith (\Encryption{..} decFactor -> encryption_vault / decFactor)
                 encByChoice
                 decFactorByChoice
        -- Pairs the leading 'groupGenPowers' with their index 0..tally_countMax.
        logTable = Map.fromList $ List.zip groupGenPowers [0..tally_countMax]
        discreteLog pow =
                maybe (throwE ErrorDecryptionShare_InvalidMaxCount) return $
                Map.lookup pow logTable

-- | @('verifyTally' tally decShareCombinator)@
-- recombines the 'tally_decShareByTrustee' of @tally@ with @decShareCombinator@
-- and checks that, for every choice of every 'Question',
-- dividing the 'encryption_vault' by the combined 'DecryptionFactor'
-- gives 'groupGen' raised to the claimed count.
--
-- Throws 'ErrorDecryptionShare_Invalid' (labelled @verifyTally@) when
-- 'isoZipWith3M_' rejects the shapes of the 'Encryption's, factors and counts,
-- and 'ErrorDecryptionShare_Wrong' when a count does not match.
--
-- NOTE(review): the trustees' 'Proof's are not checked here;
-- 'verifyDecryptionShareByTrustee' does that separately.
verifyTally ::
 SubGroup q =>
 Tally q -> DecryptionShareCombinator q ->
 Except ErrorDecryptionShare ()
verifyTally Tally{..} decShareCombinator = do
        decFactorByChoiceByQuest <- decShareCombinator tally_decShareByTrustee
        isoZipWith3M_ invalid (isoZipWith3M_ invalid checkChoice)
         tally_encByChoiceByQuest
         decFactorByChoiceByQuest
         tally_countByChoiceByQuest
    where
        invalid = throwE $ ErrorDecryptionShare_Invalid "verifyTally"
        -- The decrypted value of one choice must be groupGen^count.
        checkChoice Encryption{..} decFactor count =
                unless ((encryption_vault / decFactor) == (groupGen ^ fromNatural count)) $
                        throwE ErrorDecryptionShare_Wrong

-- ** Type 'DecryptionShare'
-- | A decryption share. It is computed by a trustee
-- from its 'SecretKey' share and the 'EncryptedTally'
-- (see 'proveDecryptionShare'),
-- and contains a cryptographic 'Proof' that it hasn't cheated,
-- which 'verifyDecryptionShare' checks.
data DecryptionShare q = DecryptionShare
 { decryptionShare_factors :: [[DecryptionFactor q]]
   -- ^ 'DecryptionFactor' by choice by 'Question'.
 , decryptionShare_proofs  :: [[Proof q]]
   -- ^ 'Proof' by choice by 'Question', attesting that the corresponding
   -- entry of 'decryptionShare_factors' was correctly computed.
 } deriving (Eq,Show,Generic,NFData)

-- *** Type 'DecryptionFactor'
-- | A trustee's contribution to decrypting one 'Encryption':
-- @'encryption_nonce' '^' trusteeSecKey@, as computed by 'proveDecryptionFactor'.
type DecryptionFactor = G

-- | @('proveDecryptionShare' encByChoiceByQuest trusteeSecKey)@
-- runs 'proveDecryptionFactor' on every 'Encryption' of @encByChoiceByQuest@,
-- threading the random generator through the 'S.StateT',
-- and splits the results into the 'decryptionShare_factors'
-- and 'decryptionShare_proofs' of a 'DecryptionShare'.
proveDecryptionShare ::
 Monad m => SubGroup q => RandomGen r =>
 EncryptedTally q -> SecretKey q -> S.StateT r m (DecryptionShare q)
proveDecryptionShare encByChoiceByQuest trusteeSecKey = do
        factorProofByChoiceByQuest <-
                mapM (mapM (proveDecryptionFactor trusteeSecKey)) encByChoiceByQuest
        let (factors, proofs) =
                List.unzip (List.map List.unzip factorProofByChoiceByQuest)
        return DecryptionShare
         { decryptionShare_factors = factors
         , decryptionShare_proofs  = proofs
         }

-- | @('proveDecryptionFactor' trusteeSecKey enc)@
-- returns the 'DecryptionFactor' @'encryption_nonce' enc '^' trusteeSecKey@
-- together with a 'Proof' built by 'prove' over the bases
-- 'groupGen' and 'encryption_nonce', whose challenge hashes
-- the 'decryptionShareStatement' of the trustee's 'PublicKey'.
proveDecryptionFactor ::
 Monad m => SubGroup q => RandomGen r =>
 SecretKey q -> Encryption q -> S.StateT r m (DecryptionFactor q, Proof q)
proveDecryptionFactor trusteeSecKey Encryption{..} =
        prove trusteeSecKey [groupGen, encryption_nonce] (hash statement) >>= \proof ->
        return (encryption_nonce ^ trusteeSecKey, proof)
    where
        statement = decryptionShareStatement (publicKey trusteeSecKey)

-- | The statement hashed into the challenge of the decryption-share 'Proof's
-- of the trustee owning @pubKey@: the prefix @decrypt|@,
-- then the bytes of @pubKey@ (via 'bytesNat'), then the suffix @|@.
decryptionShareStatement :: SubGroup q => PublicKey q -> BS.ByteString
decryptionShareStatement pubKey =
        BS.concat ["decrypt|", bytesNat pubKey, "|"]

-- *** Type 'ErrorDecryptionShare'
-- | Errors raised while proving or verifying 'DecryptionShare's and 'Tally's.
data ErrorDecryptionShare
 =   ErrorDecryptionShare_Invalid Text
     -- ^ The number of 'DecryptionFactor's or
     -- the number of 'Proof's is not the same
     -- or not the expected number
     -- (more generally: lists zipped together do not line up).
     -- The 'Text' is a label identifying where this was detected.
 |   ErrorDecryptionShare_Wrong
     -- ^ The 'Proof' of a 'DecryptionFactor' is wrong,
     -- or (in 'verifyTally') a decrypted count does not match its 'Encryption'.
 |   ErrorDecryptionShare_InvalidMaxCount
     -- ^ In 'proveTally', a decrypted value is not among the powers
     -- of 'groupGen' from @0@ to 'tally_countMax'.
 deriving (Eq,Show,Generic,NFData)

-- | @('verifyDecryptionShare' encTally trusteePubKey trusteeDecShare)@
-- checks that @trusteeDecShare@
-- (supposedly submitted by a trustee whose 'PublicKey' is @trusteePubKey@)
-- is valid with respect to the 'EncryptedTally' @encTally@:
-- every 'Proof' must carry a 'proof_challenge' equal to the hash of
-- 'decryptionShareStatement' and of the commitments recomputed from
-- 'groupGen' with @trusteePubKey@ and from the 'encryption_nonce'
-- with the matching 'DecryptionFactor'.
--
-- Throws 'ErrorDecryptionShare_Invalid' (labelled @verifyDecryptionShare@)
-- on a shape mismatch and 'ErrorDecryptionShare_Wrong' on a bad 'Proof'.
verifyDecryptionShare ::
 Monad m => SubGroup q =>
 EncryptedTally q -> PublicKey q -> DecryptionShare q ->
 ExceptT ErrorDecryptionShare m ()
verifyDecryptionShare encTally trusteePubKey DecryptionShare{..} =
        isoZipWith3M_ invalid (isoZipWith3M_ invalid checkFactor)
         encTally
         decryptionShare_factors
         decryptionShare_proofs
    where
        invalid = throwE $ ErrorDecryptionShare_Invalid "verifyDecryptionShare"
        statement = decryptionShareStatement trusteePubKey
        -- The recorded challenge must match the one recomputed from the commitments.
        checkFactor Encryption{..} decFactor proof =
                unless (proof_challenge proof == challengeOf proof encryption_nonce decFactor) $
                        throwE ErrorDecryptionShare_Wrong
        challengeOf proof nonce decFactor =
                hash statement
                 [ commit proof groupGen trusteePubKey
                 , commit proof nonce decFactor
                 ]

-- | @('verifyDecryptionShareByTrustee' encTally trusteePubKeys decShares)@
-- runs 'verifyDecryptionShare' on each trustee's 'PublicKey'
-- paired with that trustee's 'DecryptionShare'.
--
-- When 'isoZipWithM_' rejects the pairing (presumably because there are not
-- as many 'DecryptionShare's as 'PublicKey's), this throws
-- 'ErrorDecryptionShare_Invalid' labelled with this function's own name.
-- That way the mismatch can be told apart from a malformed share, which
-- 'verifyDecryptionShare' reports under its own label.
verifyDecryptionShareByTrustee ::
 Monad m => SubGroup q =>
 EncryptedTally q -> [PublicKey q] -> [DecryptionShare q] ->
 ExceptT ErrorDecryptionShare m ()
verifyDecryptionShareByTrustee encTally =
        isoZipWithM_ (throwE $ ErrorDecryptionShare_Invalid "verifyDecryptionShareByTrustee")
         (verifyDecryptionShare encTally)