{-# LANGUAGE DeriveAnyClass   #-}
{-# LANGUAGE TypeApplications #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (
        ArithmeticCircuit(..),
        Arithmetic,
        ConstraintMonomial,
        Constraint,
        -- low-level functions
        constraint,
        assignment,
        addVariable,
        newVariableWithSource,
        input,
        eval,
        apply,
        forceZero
    ) where

import           Control.DeepSeq                              (NFData)
import           Control.Monad.State                          (MonadState (..), State, modify)
import           Data.List                                    (nub)
import           Data.Map.Strict                              hiding (drop, foldl, foldr, map, null, splitAt, take)
import           GHC.Generics
import           Numeric.Natural                              (Natural)
import           Optics
import           Prelude                                      hiding (Num (..), drop, length, product, splitAt, sum,
                                                               take, (!!), (^))
import qualified Prelude                                      as Haskell
import           System.Random                                (StdGen, mkStdGen, uniform, uniformR)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field              (Zp, fromZp, toZp)
import           ZkFold.Base.Algebra.EllipticCurve.BLS12_381  (BLS12_381_Scalar)
import           ZkFold.Base.Algebra.Polynomials.Multivariate (Monomial', Polynomial', evalMapM, evalPolynomial,
                                                               mapCoeffs, var)
import           ZkFold.Prelude                               (drop, length)

-- | Arithmetic circuit in the form of a system of polynomial constraints.
data ArithmeticCircuit a = ArithmeticCircuit
    {
        acSystem   :: Map Natural (Constraint a),
        -- ^ The system of polynomial constraints, keyed by the index that
        -- 'toVar' derives from each constraint
        acInput    :: [Natural],
        -- ^ The input variables
        acWitness  :: Map Natural a -> Map Natural a,
        -- ^ The witness generation function
        acOutput   :: Natural,
        -- ^ The output variable
        acVarOrder :: Map (Natural, Natural) Natural,
        -- ^ The order of variable assignments: each key pairs the position
        -- at which a variable was recorded with the variable itself
        acRNG      :: StdGen
        -- ^ The random generator state consumed by 'newVariableWithSource'
    } deriving (Generic, NFData)

----------------------------------- Circuit monoid ----------------------------------

-- | Combines two circuits field by field.
--
-- Constraint systems and variable orders are merged with the left-biased
-- 'union', so on a key clash the entry of the left operand is kept. The
-- combined witness function runs both witness functions on the same map and
-- unions their results the same way. The output variable is the larger of
-- the two outputs, and the new generator is seeded with the product of one
-- 'Int' drawn from each operand's generator.
instance Eq a => Semigroup (ArithmeticCircuit a) where
    r1 <> r2 = ArithmeticCircuit
        {
            acSystem   = acSystem r1 `union` acSystem r2,
            -- NOTE: is it possible that we get a wrong argument order when doing `apply` because of this concatenation?
            -- We need a way to ensure the correct order no matter how `(<>)` is used.
            acInput    = nub (acInput r1 ++ acInput r2),
            acWitness  = \w -> acWitness r1 w `union` acWitness r2 w,
            acOutput   = max (acOutput r1) (acOutput r2),
            acVarOrder = acVarOrder r1 `union` acVarOrder r2,
            acRNG      = mkStdGen (seed1 Haskell.* seed2)
        }
      where
        -- One value drawn from each operand's generator; the advanced
        -- generators themselves are discarded.
        seed1, seed2 :: Int
        seed1 = fst (uniform (acRNG r1))
        seed2 = fst (uniform (acRNG r2))

-- | The empty circuit: no constraints, no inputs, no recorded variable
-- order, output variable @0@, and a generator seeded with @0@. Its witness
-- function inserts the value 'one' at variable @0@ of the map it receives.
instance (FiniteField a, Eq a) => Monoid (ArithmeticCircuit a) where
    mempty = ArithmeticCircuit
        {
            acSystem   = empty,
            acInput    = [],
            acWitness  = insert 0 one,
            acOutput   = 0,
            acVarOrder = empty,
            acRNG      = mkStdGen 0
        }

------------------------------------- Variables -------------------------------------

-- | A finite field of a large order: 'Zp' indexed by 'BLS12_381_Scalar'.
-- It is used in the compiler for generating new variable indices
-- ('toVar' performs its arithmetic in this field).
type VarField = Zp BLS12_381_Scalar

-- | Converts a circuit element into 'VarField'.
--
-- The element is taken to its binary expansion, the bits are cast to
-- naturals and reassembled into a single 'Natural', which is then converted
-- to an 'Integer' and injected into the field with 'toZp'.
toField :: Arithmetic a => a -> VarField
toField x = toZp (fromConstant asNatural)
  where
    asNatural = fromBinary @Natural (castBits (binaryExpansion x))

-- | Constraint synonym bundling the class requirements ('FiniteField', 'Eq'
-- and 'BinaryExpansion') that the functions of this module place on the
-- element type of a circuit.
type Arithmetic a = (FiniteField a, Eq a, BinaryExpansion a)

-- | Derives a variable index from a list of source variables and a
-- constraint.
--
-- The constraint's coefficients are mapped into 'VarField' with 'toField'
-- and the polynomial is evaluated with every variable index @i@ replaced by
-- @i + offset@. The fixed @base@ is raised to that value to give @x@. The
-- result is @x@ plus the sum of @x ^ p@ over every source @p@, all computed
-- in 'VarField' and converted back to a 'Natural' with 'fromZp'.
--
-- TODO: Remove the hardcoded constant.
toVar :: forall a . Arithmetic a => [Natural] -> Constraint a -> Natural
toVar srcs c = fromZp ex
    where
        -- Hard-coded field element added to every variable index.
        offset = toZp 903489679376934896793395274328947923579382759823 :: VarField
        -- Hard-coded field element used as the base of the exponentiation.
        base   = toZp 89175291725091202781479751781509570912743212325 :: VarField
        -- Value substituted for variable @i@ when evaluating the constraint.
        assign = (+ offset) . fromConstant
        x      = base ^ fromZp (evalPolynomial evalMapM assign (mapCoeffs toField c))
        -- With no sources this is just @x@.
        ex     = foldr (\p acc -> x ^ p + acc) x srcs

-- | Computes the index of a new variable from its sources and a constraint
-- template.
--
-- A random natural in the range @(0, order \@VarField -! 1)@ is drawn from
-- 'acRNG' (the advanced generator is stored back into the circuit), handed
-- to @con@ to build a constraint, and the index is obtained by applying
-- 'toVar' to @srcs@ and that constraint. Only 'acRNG' is modified; the
-- constraint itself is not added to the circuit here.
newVariableWithSource :: Arithmetic a => [Natural] -> (Natural -> Constraint a) -> State (ArithmeticCircuit a) Natural
newVariableWithSource srcs con = do
    -- 'state' runs the generator step and writes the new generator back.
    sample <- zoom #acRNG $ state (uniformR (0, order @VarField -! 1))
    pure (toVar srcs (con sample))

-- | Makes @x@ the current output variable of the circuit and records it in
-- 'acVarOrder' under the key @(n, x)@, where @n@ is the number of entries
-- recorded so far. Returns @x@ unchanged.
addVariable :: Natural -> State (ArithmeticCircuit a) Natural
addVariable x = do
    zoom #acOutput (put x)
    zoom #acVarOrder $ modify (\order' -> insert (length order', x) x order')
    pure x

---------------------------------- Low-level functions --------------------------------

-- | The type of a monomial in a constraint; a synonym for 'Monomial''.
type ConstraintMonomial = Monomial'

-- | The type that represents a constraint in the arithmetic circuit:
-- a multivariate polynomial ('Polynomial'') with coefficients of type @c@.
type Constraint c = Polynomial' c

-- | Adds a constraint to the arithmetic circuit.
--
-- The constraint is stored in 'acSystem' under the key produced by 'toVar'
-- with an empty source list, so the key depends only on the constraint.
constraint :: Arithmetic a => Constraint a -> State (ArithmeticCircuit a) ()
constraint c = zoom #acSystem $ modify (insert key c)
  where
    key = toVar [] c

-- | Forces the current variable to be zero.
--
-- Reads the current output variable and adds the constraint consisting of
-- that single variable (built with 'var') to the circuit.
forceZero :: forall a . Arithmetic a => State (ArithmeticCircuit a) ()
forceZero = do
    out <- zoom #acOutput get
    constraint (var out)

-- | Adds a new variable assignment to the arithmetic circuit.
--
-- The witness function is extended with one more step: after the existing
-- witness function has produced its map @m@, the value @f m@ is inserted at
-- the current output variable.
-- TODO: forbid reassignment of variables
assignment :: (Map Natural a -> a) -> State (ArithmeticCircuit a) ()
assignment f = do
    out <- zoom #acOutput get
    let assign m = insert out (f m) m
    zoom #acWitness $ modify (assign .)

-- | Adds a new input variable to the arithmetic circuit. Returns a copy of the arithmetic circuit with this variable as output.
--
-- The new variable gets index @1@ when the circuit has no inputs yet, and
-- otherwise one more than the largest existing input index. It is appended
-- to 'acInput', becomes 'acOutput', and is recorded in 'acVarOrder'.
input :: forall a . State (ArithmeticCircuit a) (ArithmeticCircuit a)
input = do
    inputs <- zoom #acInput get
    -- 'maximum' is only reached for a non-empty list.
    let s = case inputs of
            [] -> 1
            _  -> maximum inputs + 1
    zoom #acInput $ modify (++ [s])
    zoom #acOutput (put s)
    zoom #acVarOrder $ modify (\order' -> insert (length order', s) s order')
    get

-- | Evaluates the arithmetic circuit using the supplied input map.
--
-- Runs the witness function on the input map and looks up the output
-- variable in the result. The lookup uses the partial operator '!', so this
-- raises an error if the witness map has no entry for 'acOutput'.
eval :: ArithmeticCircuit a -> Map Natural a -> a
eval ctx i = witness ! acOutput ctx
  where
    witness = acWitness ctx i

-- | Applies the values of the first `n` inputs to the arithmetic circuit.
--
-- The current inputs are paired positionally with @xs@ ('zip', so the longer
-- list is silently truncated) and the first @length xs@ entries are removed
-- from 'acInput'. The witness function is pre-composed with a left-biased
-- 'union', so the supplied values are merged into the map before the
-- existing witness function runs and win over entries with the same key.
-- TODO: make this safe
apply :: [a] -> State (ArithmeticCircuit a) ()
apply xs = do
    inputs <- zoom #acInput get
    let applied = fromList (zip inputs xs)
    zoom #acInput $ put (drop (length xs) inputs)
    zoom #acWitness $ modify (\witness -> witness . union applied)

-- TODO: Add proper symbolic application functions

-- applySymOne :: ArithmeticCircuit a -> State (ArithmeticCircuit a) ()
-- applySymOne x = modify (\(f :: ArithmeticCircuit a) ->
--     let ins = acInput f
--     in f
--     {
--         acInput = tail ins,
--         acWitness = acWitness f . (singleton (head ins) (eval x empty)  `union`)
--     })

-- applySym :: [ArithmeticCircuit a] -> State (ArithmeticCircuit a) ()
-- applySym = foldr ((>>) . applySymOne) (return ())

-- applySymArgs :: ArithmeticCircuit a -> [ArithmeticCircuit a] -> ArithmeticCircuit a
-- applySymArgs x xs = execState (applySym xs) x