{-# LANGUAGE TypeApplications #-}

module ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators (
    boolCheckC,
    embed,
    expansion,
    splitExpansion,
    horner,
    isZeroC,
    invertC,
) where

import           Data.Foldable                                             (foldlM)
import           Data.Traversable                                          (for)
import           Numeric.Natural                                           (Natural)
import           Prelude                                                   hiding (Bool, Eq (..), negate, splitAt, (!!),
                                                                            (*), (+), (-), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Prelude                                            (splitAt, (!!))
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal       (Arithmetic, ArithmeticCircuit)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import           ZkFold.Symbolic.Data.Bool                                 (Bool)
import           ZkFold.Symbolic.Data.Conditional                          (Conditional (..))
import           ZkFold.Symbolic.Data.Eq                                   (Eq (..))

boolCheckC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
-- ^ @boolCheckC r@ computes @r (r - 1)@ in one PLONK constraint.
--
-- Over a field this value vanishes exactly when @r@ is @0@ or @1@, so the
-- output can be used to check that @r@ is boolean.
boolCheckC r = circuit $ do
    i <- runCircuit r
    newAssigned (\x -> x i * (x i - one))

embed :: Arithmetic a => a -> ArithmeticCircuit a
-- ^ @embed x@ is a circuit whose single output variable is assigned the
-- constant @x@ (it does not depend on any input).
embed x = circuit $ newAssigned $ const (fromConstant x)

expansion :: MonadBlueprint i a m => Natural -> i -> m [i]
-- ^ @expansion n k@ computes a binary expansion of @k@ if it fits in @n@ bits.
--
-- Returns the bit variables created by @bitsOf n k@ (least significant bit
-- first) and adds a constraint equating @k@ with their weighted sum.
expansion n k = do
    bits <- bitsOf n k
    -- Recompose the bits and tie the result back to @k@: the constraint
    -- polynomial @k - k'@ must vanish.
    k' <- horner bits
    constraint (\x -> x k - x k')
    return bits

splitExpansion :: MonadBlueprint i a m => Natural -> Natural -> i -> m (i, i)
-- ^ @splitExpansion n1 n2 k@ computes two values @(l, h)@ such that
-- @k = 2^n1 h + l@, @l@ fits in @n1@ bits and @h@ fits in @n2@ bits (if such
-- values exist).
splitExpansion n1 n2 k = do
    -- Decompose @k@ into @n1 + n2@ boolean variables, least significant first,
    -- so the first @n1@ form the low part and the rest the high part.
    bits <- bitsOf (n1 + n2) k
    let (lo, hi) = splitAt n1 bits
    l <- horner lo
    h <- horner hi
    -- Enforce @k - l - 2^n1 * h = 0@.
    constraint (\x -> x k - x l - scale (2 ^ n1 :: Natural) (x h))
    return (l, h)

bitsOf :: MonadBlueprint i a m => Natural -> i -> m [i]
-- ^ @bitsOf n k@ creates @n@ bits and sets their witnesses equal to the @n@
-- least significant bits of @k@ (least significant first). Each bit is
-- constrained to be boolean (@b (b - 1) = 0@), but nothing here ties the bits
-- back to @k@; callers such as 'expansion' add that constraint.
--
-- NOTE(review): 'repr' pads to @numberOfBits@ of the witness type, so an @n@
-- larger than that presumably makes '!!' go out of range when the witness is
-- computed; verify callers keep @n@ within the field's bit width.
bitsOf n k = for [1 .. n] $ \j ->
    -- Positions are 1-based so that @n = 0@ yields no bits at all; a 0-based
    -- @[0 .. n -! 1]@ range would not be empty for @n = 0@.
    newConstrained (\x i -> x i * (x i - one)) ((!! (j -! 1)) . repr . ($ k))
    where
        -- Bit decomposition of a witness value, padded to the bit width of its type.
        repr :: forall b . (BinaryExpansion b, Finite b) => b -> [b]
        repr = padBits (numberOfBits @b) . binaryExpansion

horner :: MonadBlueprint i a m => [i] -> m i
-- ^ @horner [b0,...,bn]@ computes the sum @b0 + 2 b1 + ... + 2^n bn@ using
-- Horner's scheme. An empty list yields a fresh variable assigned @zero@.
horner xs = case reverse xs of
    []       -> newAssigned (const zero)
    -- Start from the most significant element; each step computes
    -- @b + 2 * acc@, with the doubling written as @acc + acc@.
    (b : bs) -> foldlM (\a i -> newAssigned (\x -> x i + x a + x a)) b bs

isZeroC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
-- ^ @isZeroC r@ outputs @1@ when @r@ is zero and @0@ otherwise (the first
-- component produced by 'runInvert').
isZeroC r = circuit $ fst <$> runInvert r

invertC :: Arithmetic a => ArithmeticCircuit a -> ArithmeticCircuit a
-- ^ @invertC r@ outputs the multiplicative inverse of @r@ when @r@ is
-- non-zero (the second component produced by 'runInvert').
--
-- NOTE(review): when @r = 0@ the constraints from 'runInvert' do not pin this
-- output down; its witness is @finv 0@. Confirm callers do not rely on it.
invertC r = circuit $ snd <$> runInvert r

runInvert :: MonadBlueprint i a m => ArithmeticCircuit a -> m (i, i)
-- ^ @runInvert r@ embeds @r@ and returns @(j, k)@, where @j@ is @1@ if @r@ is
-- zero and @0@ otherwise, and @k@ is the inverse of @r@ when @r@ is non-zero.
--
-- The constraints @r * j = 0@ and @r * k + j - 1 = 0@ force the value of @j@:
-- a non-zero @r@ makes the first one give @j = 0@, while @r = 0@ turns the
-- second one into @j = 1@. @k@ is only determined when @r@ is non-zero.
runInvert r = do
    i <- runCircuit r
    j <- newConstrained (\x j -> x i * x j) (isZero . ($ i))
    k <- newConstrained (\x k -> x i * x k + x j - one) (finv . ($ i))
    return (j, k)
    where
      -- Witness for @j@: @one@ when the argument equals @zero@, @zero@ otherwise.
      isZero :: forall a . (Ring a, Eq (Bool a) a, Conditional (Bool a) a) => a -> a
      isZero x = bool @(Bool a) zero one (x == zero)