{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Protocol.ARK.Protostar.Lookup where

import           Data.Map                                        (fromList, mapWithKey)
import           Data.These                                      (These (..))
import           Data.Zip
import           Numeric.Natural                                 (Natural)
import           Prelude                                         hiding (Num (..), repeat, sum, zip, zipWith, (!!), (/),
                                                                  (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field                 (Zp)
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.Polynomials.Multivariate    (Polynomial')
import           ZkFold.Base.Data.Sparse.Vector                  (SVector (..))
import           ZkFold.Base.Data.Vector                         (Vector)
import           ZkFold.Base.Protocol.ARK.Protostar.SpecialSound (SpecialSoundProtocol (..), SpecialSoundTranscript)
import           ZkFold.Symbolic.Compiler                        (Arithmetic)

-- | Type-level tag selecting the Protostar lookup argument.
--
-- @l@ is the length of the witness vector being looked up and @sizeT@ is the
-- size of the lookup table. The type has no values; it only indexes the
-- 'SpecialSoundProtocol' instance below.
data ProtostarLookup
    (l     :: Natural)
    (sizeT :: Natural)

-- | Public input of the lookup argument: the table @t@ together with its
-- inverse @t^{-1}@.
data ProtostarLookupParams f sizeT =
    ProtostarLookupParams
        (Zp sizeT -> f)
        -- ^ t: the table entry stored at each index
        (f -> [Zp sizeT])
        -- ^ t^{-1}: table indices for a given value (the prover uses these to
        -- build the multiplicity vector)

instance (Arithmetic f, KnownNat sizeT) => SpecialSoundProtocol f (ProtostarLookup l sizeT) where
    type Witness f (ProtostarLookup l sizeT)         = Vector l f
    -- ^ w in the paper
    type Input f (ProtostarLookup l sizeT)           = ProtostarLookupParams f sizeT
    -- ^ t and t^{-1} from the paper
    type ProverMessage t (ProtostarLookup l sizeT)   = (Vector l t, SVector sizeT t)
    -- ^ (w, m) or (h, g) in the paper
    type VerifierMessage t (ProtostarLookup l sizeT) = t

    type Dimension (ProtostarLookup l sizeT)         = l + sizeT + 1
    type Degree (ProtostarLookup l sizeT)            = 2

    -- Two prover messages: (w, m) first, then (h, g) after the challenge r.
    rounds :: ProtostarLookup l sizeT -> Natural
    rounds _ = 2

    prover :: ProtostarLookup l sizeT
           -> Witness f (ProtostarLookup l sizeT)
           -> Input f (ProtostarLookup l sizeT)
           -> SpecialSoundTranscript f (ProtostarLookup l sizeT)
           -> ProverMessage f (ProtostarLookup l sizeT)
    -- Round 1 (empty transcript): send the witness w and the multiplicity
    -- vector m.
    prover _ w (ProtostarLookupParams _ invT) [] =
        let -- Each witness entry w_i contributes a one at every index returned
            -- by invT w_i; summing these sparse vectors yields m.
            -- NOTE(review): if invT returns several indices for one value
            -- (duplicate table entries), that value is counted once per index,
            -- which breaks sum_i 1/(w_i + r) == sum_j m_j/(t_j + r). Confirm
            -- invT returns a single index per value.
            m :: SVector sizeT f
            m = sum (fmap (\w_i -> SVector (fromList [ (i, one) | i <- invT w_i ])) w)
        in (w, m)
    -- Round 2: given the challenge r, send h_i = 1 / (w_i + r) and
    -- g_j = m_j / (t_j + r).
    prover _ _ (ProtostarLookupParams t _) [((w, m), r)] =
        let h :: Vector l f
            h = fmap (\w_i -> one // (w_i + r)) w
            g :: SVector sizeT f
            g = SVector $ mapWithKey (\i m_i -> m_i // (t i + r)) $ fromSVector m
        in (h, g)
    prover _ _ _ _ = error "Invalid transcript"

    -- TODO: implement this
    verifier' :: ProtostarLookup l sizeT
              -> Input f (ProtostarLookup l sizeT)
              -> SpecialSoundTranscript Natural (ProtostarLookup l sizeT)
              -> Vector (Dimension (ProtostarLookup l sizeT)) (Polynomial' f)
    verifier' = error "verifier': not implemented for ProtostarLookup"

    verifier :: ProtostarLookup l sizeT
             -> Input f (ProtostarLookup l sizeT)
             -> SpecialSoundTranscript f (ProtostarLookup l sizeT)
             -> Bool
    -- Accepts iff all three lookup identities hold for the two-round transcript
    -- [((w, m), r), ((h, g), _)].
    verifier _ (ProtostarLookupParams t _) [((w, m), r), ((h, g), _)] =
        let -- c1: sum_i h_i == sum_j g_j
            c1 :: Bool
            c1 = sum h == sum g
            -- c2: h_i * (w_i + r) == 1 for every witness entry
            c2 :: Bool
            c2 = all (== one) $ zipWith (*) h (fmap (+ r) w)
            -- g_j * (t_j + r), which must reproduce m_j. This is built from g;
            -- building it from m would reject honest transcripts.
            g' :: SVector sizeT f
            g' = SVector $ mapWithKey (\i g_i -> g_i * (t i + r)) $ fromSVector g
            -- A sparse entry matches only if both vectors have it with equal
            -- values.
            sameEntry :: These f f -> Bool
            sameEntry (These x y) = x == y
            sameEntry _           = False
            -- c3: g_j * (t_j + r) == m_j for every index j, on identical supports
            c3 :: Bool
            c3 = and $ alignWith sameEntry g' m
        in c1 && c2 && c3
    verifier _ _ _ = error "Invalid transcript"