module Bulletproofs.Utils where

import Protolude

import qualified Crypto.PubKey.ECC.Prim as Crypto
import qualified Crypto.PubKey.ECC.Types as Crypto
import Crypto.Number.Generate (generateMax)
import PrimeField (PrimeField, toInt)

import Bulletproofs.Fq as Fq hiding (asInteger)
import Bulletproofs.Curve

-- | Build the list @[a^0, a^1, ..., a^(x-1)]@ of the first @x@ powers of @a@.
--
-- When @a@ is zero the leading entry is @0@ rather than @0^0 = 1@, so the
-- result for a zero base is a list of zeroes. A non-positive @x@ gives @[]@.
powerVector :: (Eq f, Num f) => f -> Integer -> [f]
powerVector a x = [ power i | i <- [0 .. x - 1] ]
  where
    -- The zero base is special-cased only for the zeroth power.
    power i
      | i == 0 && a == 0 = 0
      | otherwise = a ^ i

-- | Entry-wise (Hadamard) product of two vectors.
--
-- Both lists must have the same length; otherwise the function panics.
hadamardp :: Num a => [a] -> [a] -> [a]
hadamardp a b
  | length a /= length b = panic "Vector sizes must match"
  | otherwise = zipWith (*) a b

-- | Inner product of two vectors: the sum of their entry-wise products.
-- Inherits the equal-length requirement of 'hadamardp'.
dot :: Num a => [a] -> [a] -> a
dot xs = sum . hadamardp xs

-- | Entry-wise sum of two vectors. The result is as long as the shorter
-- argument, because it is built with 'zipWith'.
(^+^) :: Num a => [a] -> [a] -> [a]
xs ^+^ ys = zipWith (+) xs ys

-- | Entry-wise difference of two vectors. The result is as long as the
-- shorter argument, because it is built with 'zipWith'.
(^-^) :: Num a => [a] -> [a] -> [a]
xs ^-^ ys = zipWith (-) xs ys

-- | Add two points of the same curve.
--
-- Uses the module-wide @curve@, the same curve 'subP' and 'mulP' operate on.
addP :: Crypto.Point -> Crypto.Point -> Crypto.Point
addP = Crypto.pointAdd curve

-- | Subtract the second point from the first (both on the same curve):
-- the first point is added to the negation of the second.
subP :: Crypto.Point -> Crypto.Point -> Crypto.Point
subP minuend subtrahend = Crypto.pointAdd curve minuend negated
  where
    negated = Crypto.pointNegate curve subtrahend

-- | Scalar multiplication on the elliptic curve. The field element is
-- converted to an 'Integer' with 'toInt' before multiplying the point.
mulP :: PrimeField p -> Crypto.Point -> Crypto.Point
mulP = Crypto.pointMul curve . toInt

-- | Double scalar multiplication (Shamir's trick): the sum of the first
-- point times the first scalar and the second point times the second
-- scalar, computed by a single call to 'Crypto.pointAddTwoMuls'.
addTwoMulP :: PrimeField p -> Crypto.Point -> PrimeField p -> Crypto.Point -> Crypto.Point
addTwoMulP scalarA pointA scalarB pointB =
  Crypto.pointAddTwoMuls curve intA pointA intB pointB
  where
    intA = toInt scalarA
    intB = toInt scalarB

-- | Multiply every point by the scalar at the same position and add up the
-- results.
--
-- Pairs are consumed two at a time through 'addTwoMulP'. Once either list
-- has fewer than two elements left, at most one more scalar/point pair is
-- used and any surplus elements of the other list are ignored.
sumExps :: [PrimeField p] -> [Crypto.Point] -> Crypto.Point
sumExps scalars points = case (scalars, points) of
  -- At least two pairs available: combine them in one double multiplication.
  (s0 : s1 : restScalars, p0 : p1 : restPoints) ->
    addP (addTwoMulP s0 p0 s1 p1) (sumExps restScalars restPoints)
  -- Exactly one usable pair left (either list may still be longer).
  (s : _, p : _) -> mulP s p
  -- One of the lists is exhausted: contribute the point at infinity.
  _ -> Crypto.PointO

-- | Pedersen commitment to a value under a blinding factor: the value
-- multiplies the generator @g@ and the blinding factor multiplies @h@,
-- and the two products are added.
commit :: PrimeField p -> PrimeField p -> Crypto.Point
commit value blinding = addTwoMulP value g blinding h

-- | Test whether the argument is a positive power of two (@1@ counts, as
-- @2^0@). Zero and negative numbers yield 'False'.
isLogBase2 :: Integer -> Bool
isLogBase2 x = case x of
  1 -> True
  0 -> False
  _
    | hasOddRemainder -> False
    | otherwise -> isLogBase2 (x `div` 2)
  where
    hasOddRemainder = x `mod` 2 /= 0

-- | Floor of the base-2 logarithm of a positive integer.
--
-- Computed with exact integer halving instead of floating point: converting
-- to 'Double' rounds large inputs (e.g. @2^60 - 1@ becomes @2^60@) and
-- overflows to infinity for very large ones, giving wrong results.
--
-- The logarithm is undefined for non-positive input; @0@ is returned there.
logBase2 :: Integer -> Integer
logBase2 = go 0
  where
    -- Count how many times the value can be halved before reaching 1.
    go acc n
      | n <= 1 = acc
      | otherwise = go (acc + 1) (n `div` 2)

-- | Base-2 logarithm of the argument, defined only when 'isLogBase2'
-- accepts it; every other input gives 'Nothing'.
logBase2M :: Integer -> Maybe Integer
logBase2M x
  | isLogBase2 x = Just (logBase2 x)
  | otherwise = Nothing

-- | Take the @j@-th chunk (counting from 1) of @n@ consecutive elements:
-- the first @(j - 1) * n@ elements are skipped and up to @n@ of the
-- following ones are returned.
slice :: Integer -> Integer -> [a] -> [a]
slice n j vs = take chunkSize (drop offset vs)
  where
    -- j * n - (j - 1) * n is exactly n.
    chunkSize = fromIntegral n
    offset = fromIntegral ((j - 1) * n)

-- | Append the minimal number of zeroes so that the length of the list
-- becomes a power of two. The empty list is returned unchanged.
padToNearestPowerOfTwo :: Num f => [f] -> [f]
padToNearestPowerOfTwo [] = []
padToNearestPowerOfTwo xs = padToNearestPowerOfTwoOf (length xs) xs

-- | Given a size @i@, append zeroes until the list has length
-- @2 ^ 'log2Ceil' i@, i.e. the smallest power of two that is at least @i@.
--
-- @i@ is expected to be positive. A list that is already at least that long
-- is returned unchanged, since the number of zeroes to add is then
-- non-positive.
padToNearestPowerOfTwoOf
  :: Num f
  => Int -- ^ size to round up to a power of two
  -> [f] -- ^ list to pad
  -> [f] -- ^ the input list followed by the padding zeroes
padToNearestPowerOfTwoOf i xs = xs ++ replicate padLength 0
  where
    padLength = nearestPowerOfTwo - length xs
    nearestPowerOfTwo = 2 ^ log2Ceil i

-- | Ceiling of the base-2 logarithm of a positive 'Int'.
--
-- The position of the highest set bit gives the floor of the logarithm.
-- If a lower bit is set as well (the lowest set bit sits below the highest
-- one), the value is not a power of two and the result is rounded up.
log2Ceil :: Int -> Int
log2Ceil x
  | countTrailingZeros x < highestBit = highestBit + 1
  | otherwise = highestBit
  where
    highestBit = finiteBitSize x - countLeadingZeros x - 1

-- | Draw a random integer by calling 'generateMax' with the bound @2^n@.
-- (Per the cryptonite documentation the result should lie in @[0, 2^n)@;
-- confirm against the library version in use.)
randomN :: MonadRandom m => Integer -> m Integer
randomN n = generateMax upperBound
  where
    upperBound = 2 ^ n

-- | Draw two independent random vectors of @n@ elements each. Every element
-- is obtained from 'generateMax' with the bound @2^n@ and converted into
-- the target numeric type with 'fromInteger'.
chooseBlindingVectors :: (Num f, MonadRandom m) => Integer -> m ([f], [f])
chooseBlindingVectors n = do
  sL <- blindingVector
  sR <- blindingVector
  pure (sL, sR)
  where
    -- One vector of n freshly drawn elements.
    blindingVector = replicateM (fromInteger n) randomElement
    randomElement = fmap fromInteger (generateMax (2 ^ n))

--------------------------------------------------
-- Fiat-Shamir transformations
--------------------------------------------------

-- | Challenge @y@: the group order @_q@ (via 'show') and the two commitments
-- (via @pointToBS@) are concatenated, passed to @oracle@, and the resulting
-- integer is converted with 'fromInteger'.
shamirY :: Num f => Crypto.Point -> Crypto.Point -> f
shamirY aCommit sCommit = fromInteger (oracle transcript)
  where
    transcript = show _q <> pointToBS aCommit <> pointToBS sCommit

-- | Challenge @z@: like 'shamirY', but the transcript passed to @oracle@
-- additionally ends with the shown challenge @y@.
shamirZ :: (Show f, Num f) => Crypto.Point -> Crypto.Point -> f -> f
shamirZ aCommit sCommit y = fromInteger (oracle transcript)
  where
    transcript =
      show _q <> pointToBS aCommit <> pointToBS sCommit <> show y

-- | Challenge @x@: the transcript passed to @oracle@ consists of @_q@, the
-- four commitments in argument order, and the shown challenges @y@ and @z@.
-- The resulting integer is converted with 'fromInteger'.
shamirX
  :: (Show f, Num f)
  => Crypto.Point
  -> Crypto.Point
  -> Crypto.Point
  -> Crypto.Point
  -> f
  -> f
  -> f
shamirX aCommit sCommit t1Commit t2Commit y z = fromInteger (oracle transcript)
  where
    transcript =
      show _q
        <> pointToBS aCommit
        <> pointToBS sCommit
        <> pointToBS t1Commit
        <> pointToBS t2Commit
        <> show y
        <> show z

-- | Challenge derived from three points. Note the transcript order:
-- @_q@, then @l'@, then @r'@, and the first argument @commitmentLR@ last.
shamirX'
  :: Num f
  => Crypto.Point
  -> Crypto.Point
  -> Crypto.Point
  -> f
shamirX' commitmentLR l' r' = fromInteger (oracle transcript)
  where
    transcript =
      show _q <> pointToBS l' <> pointToBS r' <> pointToBS commitmentLR

-- | Challenge @u@: @_q@ and the three field elements are rendered with
-- 'show', concatenated in argument order, passed to @oracle@, and the
-- resulting integer is converted with 'fromInteger'.
shamirU :: (Show f, Num f) => f -> f -> f -> f
shamirU tBlinding mu t = fromInteger (oracle transcript)
  where
    transcript = show _q <> show tBlinding <> show mu <> show t