{-# LANGUAGE TypeApplications #-}
module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (
boolCheckC,
embed,
expansion,
splitExpansion,
horner,
isZeroC,
invertC,
) where
import Data.Foldable (foldlM)
import Data.Traversable (for)
import Numeric.Natural (Natural)
import Prelude hiding (Bool, Eq (..), negate, splitAt, (!!),
(*), (+), (-), (^))
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Prelude (splitAt, (!!))
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import ZkFold.Symbolic.Data.Bool (Bool)
import ZkFold.Symbolic.Data.Conditional (Conditional (..))
import ZkFold.Symbolic.Data.Eq (Eq (..))
-- | @boolCheckC r@ builds a circuit whose output is a fresh variable holding
-- @r * (r - 1)@, allocated with a single 'newAssigned'.
--
-- Over a field this value vanishes exactly when @r@ is 0 or 1, so it is a
-- booleanity test. It only computes the value; no constraint forcing it to
-- zero is added here. NOTE(review): presumably callers constrain the result
-- to zero themselves — verify at call sites.
boolCheckC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
boolCheckC r =
    circuit
        ( runCircuit r >>= \v ->
            newAssigned (\x -> x v * (x v - one))
        )
-- | @embed c@ builds a circuit whose output is a fresh variable assigned the
-- constant @c@. The assignment polynomial ignores the variable lookup.
embed :: Arithmetic a => a -> ArithmeticCircuit a
embed c = circuit (newAssigned (\_ -> fromConstant c))
-- | @expansion n k@ decomposes the variable @k@ into @n@ boolean variables
-- and returns them, least significant first.
--
-- The bits are allocated by 'bitsOf' (each constrained to be 0 or 1). They
-- are then recombined with 'horner', and the recombination is constrained to
-- equal @k@. The system is therefore only satisfiable when @k@ fits in @n@
-- bits (assuming @n@ is below the field's bit width, so there is no
-- wrap-around).
expansion :: MonadBlueprint i a m => Natural -> i -> m [i]
expansion n k = do
    bits <- bitsOf n k
    recombined <- horner bits
    constraint (\x -> x k - x recombined)
    pure bits
-- | @splitExpansion n1 n2 k@ returns variables @(l, h)@ with
-- @k = l + 2^n1 * h@.
--
-- * @l@ is recombined from the lowest @n1@ bits of @k@.
-- * @h@ is recombined from the following @n2@ bits.
--
-- It allocates @n1 + n2@ boolean variables via 'bitsOf' and recombines each
-- half with 'horner', low half first. It then adds one linear constraint
-- tying @l@ and @h@ back to @k@.
splitExpansion :: MonadBlueprint i a m => Natural -> Natural -> i -> m (i, i)
splitExpansion n1 n2 k = do
    bits <- bitsOf (n1 + n2) k
    let (lowBits, highBits) = splitAt n1 bits
        weight = 2 ^ n1 :: Natural
    lo <- horner lowBits
    hi <- horner highBits
    constraint (\x -> x k - x lo - scale weight (x hi))
    pure (lo, hi)
-- | @bitsOf n k@ allocates @n@ fresh variables and returns them.
--
-- * Each variable is constrained to be boolean (@b * (b - 1) = 0@).
-- * The witness of the @j@-th variable is entry @j@ of the padded binary
--   expansion of @k@'s value.
--
-- The bits are NOT tied back to @k@ here; callers such as 'expansion' and
-- 'splitExpansion' add that constraint through 'horner'.
--
-- For @n = 0@ no variables are created. The previous range
-- @[0 .. n -! 1]@ was wrong in that case: @0 -! 1@ either underflows or,
-- if @-!@ truncates, yields @[0]@, i.e. one spurious bit.
bitsOf :: MonadBlueprint i a m => Natural -> i -> m [i]
bitsOf n k = for (takeWhile (< n) [0 ..]) $ \j ->
    newConstrained (\x i -> x i * (x i - one)) ((!! j) . repr . ($ k))
    where
        -- Binary expansion padded to the full bit width of the value's type.
        -- NOTE(review): an index j >= numberOfBits @b is out of range for
        -- this list; presumably callers keep n within the field's bit width.
        repr :: forall b . (BinaryExpansion b, Finite b) => b -> [b]
        repr = padBits (numberOfBits @b) . binaryExpansion
-- | @horner [b0, b1, ..., bn]@ returns a variable equal to
-- @b0 + 2 * b1 + ... + 2^n * bn@; the list is read least significant first.
--
-- It uses Horner's scheme: start from the last element, and for each earlier
-- element @b@ allocate @b + acc + acc@ with one 'newAssigned'.
--
-- * A single-element list returns that element itself, with no new variable.
-- * The empty list yields a fresh variable assigned zero.
horner :: MonadBlueprint i a m => [i] -> m i
horner xs =
    case reverse xs of
        msb : rest ->
            foldlM (\acc bit -> newAssigned (\x -> x bit + x acc + x acc)) msb rest
        [] ->
            newAssigned (const zero)
-- | @isZeroC r@ outputs the zero-indicator of @r@: 1 when @r@ is zero and
-- 0 otherwise. It is the first component allocated by 'runInvert'.
isZeroC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
isZeroC r = circuit (fmap fst (runInvert r))
-- | @invertC r@ outputs the multiplicative inverse of @r@. It is the second
-- component allocated by 'runInvert'.
--
-- NOTE(review): when @r@ is zero, the constraints in 'runInvert' leave this
-- output unconstrained; only its witness (@finv@ of zero) fixes a value.
-- Confirm that callers do not rely on it in that case.
invertC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
invertC r = circuit (fmap snd (runInvert r))
-- | @runInvert r@ embeds @r@ as a variable @v@ and allocates two more
-- variables, returned as @(z, inv)@:
--
-- * @z@ has witness @isZero v@ and is constrained by @v * z = 0@.
-- * @inv@ has witness @finv v@ and is constrained by @v * inv + z - 1 = 0@.
--
-- If @v /= 0@, the first constraint forces @z = 0@ and the second then forces
-- @inv = 1 / v@. If @v = 0@, the second constraint forces @z = 1@ but places
-- no restriction on @inv@.
runInvert :: MonadBlueprint i a m => ArithmeticCircuit a -> m (i, i)
runInvert r = do
    v <- runCircuit r
    z <- newConstrained (\x new -> x v * x new) (\w -> isZero (w v))
    inv <- newConstrained (\x new -> x v * x new + x z - one) (\w -> finv (w v))
    pure (z, inv)
    where
        -- Witness-level zero test. It yields @one@ when the argument equals
        -- @zero@ and @zero@ otherwise, assuming 'bool' takes the false branch
        -- first, like Data.Bool.bool.
        isZero :: forall a . (Ring a, Eq (Bool a) a, Conditional (Bool a) a) => a -> a
        isZero x = bool @(Bool a) zero one (x == zero)