{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint (
ClosedPoly,
MonadBlueprint (..),
NewConstraint,
Witness,
WitnessField,
circuit,
circuits,
) where
import Control.Monad.Identity (Identity (..))
import Control.Monad.State (State, gets, modify, runState)
import Data.Functor (($>))
import Data.Map ((!))
import Data.Set (Set)
import qualified Data.Set as Set
import Numeric.Natural (Natural)
import Prelude hiding (Bool (..), Eq (..), replicate, (*), (+),
(-))
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Sources
import ZkFold.Base.Algebra.Polynomials.Multivariate (var)
import qualified ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal as I
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal hiding (constraint)
import ZkFold.Symbolic.Data.Bool (Bool (..))
import ZkFold.Symbolic.Data.Conditional (Conditional (..))
import ZkFold.Symbolic.Data.Eq (Eq (..))
-- | Everything a witness may assume about the carrier @x@ it is evaluated in,
-- for a circuit over base field @a@: an @a@-algebra that is also a finite
-- field with binary expansion, plus symbolic equality and @bool@-style
-- selection over 'Bool' @x@.
type WitnessField a x = (Algebra a x, FiniteField x, BinaryExpansion x,
    Eq (Bool x) x, Conditional (Bool x) x, Conditional (Bool x) (Bool x))
-- | A witness computation: given a value for every variable index @i@, it
-- computes a value in /any/ carrier @w@ satisfying 'WitnessField' @a@.
-- Being polymorphic in @w@ lets the same witness be run both on concrete
-- field elements and on other interpretations (e.g. 'Sources').
type Witness i a = forall w. WitnessField a w => (i -> w) -> w

-- | A constraint that mentions a not-yet-allocated variable: the first
-- argument interprets existing variables, the second is the index of the
-- new variable.
type NewConstraint i a = forall w. Algebra a w => (i -> w) -> i -> w

-- | A polynomial expression over existing variables only, interpretable in
-- any @a@-algebra @w@.
type ClosedPoly i a = forall w. Algebra a w => (i -> w) -> w
-- | A monad in which an arithmetic circuit over base field @a@ is assembled
-- step by step. @i@ is the type of variable handles the monad gives out;
-- both @i@ and @a@ are determined by @m@ (functional dependencies).
class Monad m => MonadBlueprint i a m | m -> i, m -> a where
    -- | Introduce an input variable and return its handle.
    -- (Meaning inferred from the name; the 'State' instance delegates to
    -- 'I.input'.)
    input :: m i

    -- | The circuit built so far, with the given variable as its output.
    output :: i -> m (ArithmeticCircuit a)

    -- | Embed an already-built circuit and return the handle of its output.
    runCircuit :: ArithmeticCircuit a -> m i

    -- | Allocate a new variable. Its value is computed by the witness, and it
    -- is tied to existing variables by the constraint, which receives the
    -- new variable's index as its second argument.
    newConstrained :: NewConstraint i a -> Witness i a -> m i

    -- | Add a constraint given as a polynomial over existing variables
    -- (presumably asserted to equal zero; cf. the default 'newAssigned').
    constraint :: ClosedPoly i a -> m ()

    -- | Allocate a new variable @v@ equal to the polynomial @p@: the
    -- constraint is @p - v@ and the witness is @p@ itself.
    newAssigned :: ClosedPoly i a -> m i
    newAssigned p = newConstrained (\x i -> p x - x i) p
-- | The concrete blueprint: a 'State' monad threading the 'ArithmeticCircuit'
-- under construction, with variables indexed by 'Natural'.
instance Arithmetic a => MonadBlueprint Natural a (State (ArithmeticCircuit a)) where
    -- The handle is the 'acOutput' of the circuit returned by 'I.input'.
    input = acOutput <$> I.input

    -- Read-only: the current circuit with its output replaced by @i@.
    output i = gets (\r -> r { acOutput = i })

    -- Append @r@ to the circuit under construction (state becomes
    -- @st <> r@) and hand back r's output variable.
    runCircuit r = modify (<> r) $> acOutput r

    newConstrained
        :: NewConstraint Natural a
        -> Witness Natural a
        -> State (ArithmeticCircuit a) Natural
    newConstrained new witness = do
        let -- Indices the witness reads.
            witnessSources :: Set Natural
            witnessSources = sources @a witness

            -- Placeholder for the not-yet-allocated variable: strictly greater
            -- than every witness source (0 if there are none). Since it is
            -- never in 'witnessSources', it cannot disturb the difference
            -- below, even if it coincides with an index the constraint
            -- already mentions.
            fresh :: Natural
            fresh = maybe 0 (+ 1) (Set.lookupMax witnessSources)

            -- Witness sources that the constraint itself does not mention.
            -- NOTE(review): their role inside 'newVariableWithSource' is not
            -- visible here; presumably they feed the new variable's index.
            extra :: Set Natural
            extra = witnessSources `Set.difference` sources @a (`new` fresh)
        i <- addVariable =<< newVariableWithSource (Set.toList extra) (new var)
        -- Now that the real index exists, add the constraint for it.
        constraint (`new` i)
        -- Evaluate the witness against the assignment map. '!' is partial,
        -- so this assumes every source index has been assigned.
        assignment (\m -> witness (m !))
        pure i

    constraint p = I.constraint (p var)
-- | Run a blueprint that yields a single variable and return the resulting
-- circuit with that variable as its output. This is 'circuits' specialised
-- to the 'Identity' functor.
circuit :: Arithmetic a => (forall i m . MonadBlueprint i a m => m i) -> ArithmeticCircuit a
circuit b = runIdentity $ circuits (Identity <$> b)
-- | Run a blueprint that yields a container of variables, starting from the
-- empty circuit ('mempty'). Every element of the result is the same final
-- circuit, with its output set to the corresponding variable.
circuits :: (Arithmetic a, Functor f) => (forall i m . MonadBlueprint i a m => m (f i)) -> f (ArithmeticCircuit a)
circuits b =
    let (os, r) = runState b mempty
    in (\o -> r { acOutput = o }) <$> os
-- | The variable indices a witness depends on. Each variable @j@ is
-- interpreted as @Sources (Set.singleton j)@, the witness is evaluated in
-- that interpretation, and 'runSources' extracts the accumulated index set.
--
-- Written with the witness argument named (rather than point-free) so the
-- rank-N argument does not rely on deep subsumption.
sources :: forall a i . (FiniteField a, Ord i) => Witness i a -> Set i
sources witness = runSources (witness (Sources @a . Set.singleton))
-- | Symbolic comparison in the 'Sources' interpretation. Only the dependency
-- set matters, so '==' and '/=' both combine the operands' sources with '<>'.
instance Ord i => Eq (Bool (Sources a i)) (Sources a i) where
    x == y = Bool (x <> y)

    x /= y = Bool (x <> y)
-- | Selection in the 'Sources' interpretation: the result may depend on
-- either branch and on the condition, so all three are combined with '<>'.
instance (Finite a, Ord i) => Conditional (Bool (Sources a i)) (Sources a i) where
    bool x y (Bool b) = x <> y <> b
-- | Selection between symbolic booleans in the 'Sources' interpretation:
-- unwrap both branches and the condition, combine them with '<>', and
-- re-wrap the result in 'Bool'.
instance (Finite a, Ord i) => Conditional (Bool (Sources a i)) (Bool (Sources a i)) where
    bool (Bool x) (Bool y) (Bool b) = Bool (x <> y <> b)