{-# LANGUAGE RecordWildCards, MultiWayIf, NamedFieldPuns  #-}
module Bulletproofs.MultiRangeProof.Verifier (
  verifyProof,
  verifyTPoly,
  verifyLRCommitment,
) where
import Protolude
import Data.Curve.Weierstrass.SECP256K1 (PA, Fr, mul, gen)
import Bulletproofs.RangeProof.Internal
import Bulletproofs.Utils
import Bulletproofs.InnerProductProof as IPP hiding (verifyProof)
import qualified Bulletproofs.InnerProductProof as IPP
-- | Verify an aggregated range proof over several value commitments.
--
-- The proof is accepted only when both sub-checks succeed:
--
--   * 'verifyTPoly', the check on the committed polynomial evaluation @t@;
--   * 'verifyLRCommitment', the inner-product check on the @l@/@r@ vectors.
--
-- The bit length used for every value is @logBase2 upperBound@. The
-- commitment list is padded with 'mempty' points (presumably the group
-- identity) up to the next power of two before it is checked.
--
-- An empty commitment list is rejected outright. There is nothing to
-- prove, and a zero length in the padding arithmetic would either fail or
-- silently substitute a single padding commitment.
--
-- NOTE(review): the challenges @y@, @z@ and @x@ are derived only from the
-- proof's own commitments (@aCommit@, @sCommit@, @t1Commit@, @t2Commit@).
-- Neither @vCommits@ nor @upperBound@ is hashed into the transcript.
-- Changing that would also require a matching change in the prover.
verifyProof
  :: Integer          -- ^ Upper bound of the range; bit length is @logBase2 upperBound@
  -> [PA]             -- ^ Commitments to the values being range-proved
  -> RangeProof Fr PA -- ^ The aggregated proof to check
  -> Bool
verifyProof upperBound vCommits proof@RangeProof{..}
  | null vCommits = False
  | otherwise = and
      [ verifyTPoly n vCommitsExp2 proof x y z
      , verifyLRCommitment n mExp2 proof x y z
      ]
  where
    -- Fiat-Shamir challenges recomputed from the proof's commitments.
    x = shamirX aCommit sCommit t1Commit t2Commit y z
    y = shamirY aCommit sCommit
    z = shamirZ aCommit sCommit y
    n = logBase2 upperBound
    m = length vCommits

    -- The number of commitments must be a power of two, so pad with
    -- 'mempty' points up to 2 ^ log2Ceil m.
    vCommitsExp2 = vCommits ++ residueCommits
    residueCommits = replicate (2 ^ log2Ceil m - m) mempty
    mExp2 = fromIntegral $ length vCommitsExp2
-- | Check the polynomial-evaluation part of an aggregated range proof.
--
-- Returns 'True' exactly when @commit t tBlinding@ equals the point
--
-- > sum_j (z^2 * z^j) * V_j  <>  gen * delta n m y z  <>  T1 * x  <>  T2 * x^2
--
-- Here @V_j@ ranges over the supplied commitments, @m@ is their count and
-- @z^j@ comes from @powerVector z m@.
verifyTPoly
  :: Integer          -- ^ Bit length @n@ of each value
  -> [PA]             -- ^ Value commitments (already padded by the caller)
  -> RangeProof Fr PA -- ^ Proof supplying @t@, @tBlinding@, @t1Commit@, @t2Commit@
  -> Fr               -- ^ Challenge @x@
  -> Fr               -- ^ Challenge @y@
  -> Fr               -- ^ Challenge @z@
  -> Bool
verifyTPoly n vCommits RangeProof{..} x y z =
  commit t tBlinding == expected
  where
    numCommits = fromIntegral (length vCommits)

    -- Per-commitment weights z^2, z^3, ..., z^(m+1).
    weights = map (z ^ 2 *) (powerVector z numCommits)

    vTerm     = sumExps weights vCommits
    deltaTerm = gen `mul` delta n numCommits y z
    t1Term    = t1Commit `mul` x
    t2Term    = t2Commit `mul` (x ^ 2)

    expected = vTerm <> deltaTerm <> t1Term <> t2Term
-- | Check the @l@/@r@ vector part of an aggregated range proof.
--
-- The check delegates to 'IPP.verifyProof' with vector length @n * m@ and
-- the following inputs:
--
--   * the base @gs@;
--   * each @hs@ generator scaled by the inverse of the matching power of @y@;
--   * @gen@ multiplied by the challenge @shamirU tBlinding mu t@;
--   * the commitment produced by 'computeLRCommitment';
--   * the proof's @productProof@.
verifyLRCommitment
  :: Integer          -- ^ Bit length @n@ of each value
  -> Integer          -- ^ Number of (padded) value commitments @m@
  -> RangeProof Fr PA -- ^ Proof supplying the commitments and inner-product proof
  -> Fr               -- ^ Challenge @x@
  -> Fr               -- ^ Challenge @y@
  -> Fr               -- ^ Challenge @z@
  -> Bool
verifyLRCommitment n m RangeProof{..} x y z =
  IPP.verifyProof totalLen ippBase lrCommitment productProof
  where
    totalLen = n * m

    -- H'_i = H_i * y^(-i); zipWith truncates to the shorter of the two lists.
    scaledHs =
      zipWith (\yPow h -> h `mul` recip yPow) (powerVector y totalLen) hs

    ippBase = IPP.InnerProductBase
      { bGs = gs
      , bHs = scaledHs
      , bH  = gen `mul` shamirU tBlinding mu t
      }

    lrCommitment =
      computeLRCommitment n m aCommit sCommit t tBlinding mu x y z scaledHs