{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeApplications   #-}

{-# OPTIONS_GHC -Wno-orphans    #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (
    ClosedPoly,
    MonadBlueprint (..),
    NewConstraint,
    Witness,
    WitnessField,
    circuit,
    circuits,
) where

import           Control.Monad.Identity                              (Identity (..))
import           Control.Monad.State                                 (State, gets, modify, runState)
import           Data.Functor                                        (($>))
import           Data.Map                                            ((!))
import           Data.Set                                            (Set)
import qualified Data.Set                                            as Set
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Bool (..), Eq (..), replicate, (*), (+),
                                                                      (-))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Sources
import           ZkFold.Base.Algebra.Polynomials.Multivariate        (var)
import qualified ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal as I
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal hiding (constraint)
import           ZkFold.Symbolic.Data.Bool                           (Bool (..))
import           ZkFold.Symbolic.Data.Conditional                    (Conditional (..))
import           ZkFold.Symbolic.Data.Eq                             (Eq (..))

type WitnessField a w =
    ( Algebra a w
    , FiniteField w
    , BinaryExpansion w
    , Eq (Bool w) w
    , Conditional (Bool w) w
    , Conditional (Bool w) (Bool w)
    )
-- ^ DSL for constructing witnesses in an arithmetic circuit. @a@ is the base
-- field and @w@ is a "field of witnesses over @a@". It is safe to think of @w@
-- as @a@ itself, with equality and branching internalized into @w@ (they
-- produce and consume @'Bool' w@ rather than a Haskell boolean).

type Witness i a = forall w . WitnessField a w => (i -> w) -> w
-- ^ Witness builders. @i@ is the type of variables and @a@ is the base field.
--
-- A witness builder receives an arbitrary field of witnesses @w@ over @a@
-- together with a function assigning witnesses to the already-known
-- variables, and produces the witness of a new value in @w@.
--
-- NOTE: every function of this type has the property above by construction,
-- so there is nothing to verify by hand.

type NewConstraint i a = forall v . Algebra a v => (i -> v) -> i -> v
-- ^ Constraints on freshly introduced variables. @i@ is the type of variables
-- and @a@ is the base field.
--
-- Such a constraint receives an arbitrary algebra @v@ over @a@, a function
-- mapping the already-known variables to their values in @v@, and the new
-- variable itself; it returns the value of the constraint polynomial in @v@.
--
-- NOTE: every function of this type has the property above by construction,
-- so there is nothing to verify by hand.

type ClosedPoly i a = forall v . Algebra a v => (i -> v) -> v
-- ^ Polynomial expressions over known variables. @i@ is the type of variables
-- and @a@ is the base field.
--
-- Such an expression receives an arbitrary algebra @v@ over @a@ and a function
-- mapping the already-known variables to their values in @v@, and returns a
-- new value in @v@.
--
-- NOTE: every function of this type has the property above by construction,
-- so there is nothing to verify by hand.

-- | DSL for constructing arithmetic circuits. @i@ is a type of variables,
-- @a@ is a base field and @m@ is a monad for constructing the circuit.
--
-- DSL provides the following guarantees:
--
-- * There are no unconstrained variables;
-- * Variables with equal constraints and witnesses are reused as much as possible;
-- * Variables with either different constraints or different witnesses are different;
-- * There is an order in which witnesses can be generated;
-- * Constraints never reference undefined variables.
--
-- However, DSL does NOT provide the following guarantees (yet):
--
-- * That provided witnesses satisfy the provided constraints. To check this,
--   you can use 'ZkFold.Symbolic.Compiler.ArithmeticCircuit.checkCircuit'.
-- * That introduced polynomial constraints are supported by the zk-SNARK
--   utilized for later proving.
class Monad m => MonadBlueprint i a m | m -> i, m -> a where
    -- | Creates new input variable.
    input :: m i

    -- | Returns a circuit with supplied variable as output.
    output :: i -> m (ArithmeticCircuit a)

    -- | Adds the supplied circuit to the blueprint and returns its output variable.
    runCircuit :: ArithmeticCircuit a -> m i

    -- | Creates new variable given a constraint polynomial and a witness.
    newConstrained :: NewConstraint i a -> Witness i a -> m i

    -- | Adds new constraint to the system.
    constraint :: ClosedPoly i a -> m ()

    -- | Creates new variable given a polynomial witness.
    --
    -- The default implementation introduces a variable @v@ whose constraint
    -- polynomial is @p - v@ (zero exactly when @v@ equals @p@) and reuses @p@
    -- itself as the witness.
    newAssigned :: ClosedPoly i a -> m i
    newAssigned p = newConstrained (\x i -> p x - x i) p

instance Arithmetic a => MonadBlueprint Natural a (State (ArithmeticCircuit a)) where
    -- Variables are 'Natural' indices and the blueprint is executed in 'State'
    -- over the circuit being assembled.

    -- Allocates an input through 'I.input' and returns that circuit's output index.
    input = acOutput <$> I.input

    -- Returns a copy of the current circuit whose output is set to @i@;
    -- the state itself is left untouched.
    output i = gets (\r -> r { acOutput = i })

    -- Appends @r@ to the current circuit and yields @r@'s output variable.
    runCircuit r = modify (<> r) $> acOutput r

    newConstrained
        :: NewConstraint Natural a
        -> Witness Natural a
        -> State (ArithmeticCircuit a) Natural
    newConstrained new witness = do
        let ws :: Set Natural
            ws = sources @a witness
            -- We need a throwaway variable to feed into `new` which definitely
            -- would not be present in a witness: one past the largest witness
            -- variable, or 0 when the witness uses no variables. 'Set.lookupMax'
            -- is O(log n), unlike mapping over the whole set to find the maximum.
            x :: Natural
            x = maybe 0 (+ 1) (Set.lookupMax ws)
            -- `s` is meant to be a set of variables used in a witness not
            -- present in a constraint.
            s :: Set Natural
            s = ws `Set.difference` sources @a (`new` x)
        i <- addVariable =<< newVariableWithSource (Set.toList s) (new var)
        constraint (`new` i)
        assignment (\m -> witness (m !))
        return i

    -- Evaluates the polynomial on symbolic variables and records it as a constraint.
    constraint p = I.constraint (p var)

circuit :: Arithmetic a => (forall i m . MonadBlueprint i a m => m i) -> ArithmeticCircuit a
-- ^ Builds a circuit from blueprint. A blueprint is a function which, given an
-- arbitrary type of variables @i@ and a monad @m@ supporting the 'MonadBlueprint'
-- API, computes the output variable of a future circuit.
--
-- This is the single-output case of 'circuits': the output variable is wrapped
-- in 'Identity' and the resulting circuit is unwrapped again.
circuit b = runIdentity $ circuits (Identity <$> b)

circuits :: (Arithmetic a, Functor f) => (forall i m . MonadBlueprint i a m => m (f i)) -> f (ArithmeticCircuit a)
-- ^ Builds a collection of circuits from one blueprint. A blueprint is a function
-- which, given an arbitrary type of variables @i@ and a monad @m@ supporting the
-- 'MonadBlueprint' API, computes the collection of output variables of future circuits.
--
-- The blueprint runs once in 'State', starting from the empty circuit
-- ('mempty'). Every element of the result is that same accumulated circuit,
-- differing only in its 'acOutput'.
circuits b = withOutput <$> outputs
  where
    (outputs, accumulated) = runState b mempty
    withOutput o = accumulated { acOutput = o }

sources :: forall a i . (FiniteField a, Ord i) => Witness i a -> Set i
-- Collects the variables a witness builder reads. The builder is evaluated in
-- the 'Sources' algebra, where each variable is mapped to a singleton source
-- set, and the accumulated set is extracted with 'runSources'.
--
-- The witness is applied directly (eta-expanded) rather than via
-- @runSources . ($ ...)@, so its rank-N argument typechecks without relying
-- on deep subsumption.
sources w = runSources (w (Sources @a . Set.singleton))

-- | Equality over 'Sources' only tracks dependencies: either comparison
-- depends on both operands, so '==' and '/=' both return the combined ('<>')
-- sources of @x@ and @y@ wrapped in 'Bool'.
instance Ord i => Eq (Bool (Sources a i)) (Sources a i) where
  x == y = Bool (x <> y)
  x /= y = Bool (x <> y)

-- | Branching over 'Sources' only tracks dependencies: the result depends on
-- both branches and on the condition, so all three are combined with '<>'.
instance (Finite a, Ord i) => Conditional (Bool (Sources a i)) (Sources a i) where
  bool x y (Bool b) = x <> y <> b

-- | Branching between 'Bool'-wrapped 'Sources': the result depends on both
-- branches and on the condition, so their sources are combined with '<>' and
-- re-wrapped in 'Bool'.
instance (Finite a, Ord i) => Conditional (Bool (Sources a i)) (Bool (Sources a i)) where
  bool (Bool x) (Bool y) (Bool b) = Bool (x <> y <> b)