{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Symbolic.Compiler.Arithmetizable (
Arithmetic,
Arithmetizable (..),
SomeArithmetizable (..),
SomeData (..),
SymbolicData (..),
) where
import Data.Typeable (Typeable)
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), drop, length, product, splitAt,
sum, take, (!!), (^))
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
import ZkFold.Base.Data.Vector hiding (concat)
import ZkFold.Prelude (drop, length, splitAt, take)
import ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit)
-- | Types whose values can be flattened into a list of
-- @'ArithmeticCircuit' a@ and rebuilt from such a list.
class Arithmetic a => SymbolicData a x where
    -- | Flatten a value of type @x@ into its constituent circuits.
    pieces :: x -> [ArithmeticCircuit a]

    -- | Rebuild a value of type @x@ from a list of circuits. The instances
    -- in this module call 'error' when the list length differs from
    -- 'typeSize'.
    restore :: [ArithmeticCircuit a] -> x

    -- | The number of circuits that describe a value of type @x@.
    -- The type is ambiguous (neither @a@ nor @x@ appears in it), so it is
    -- used with type applications: @typeSize \@a \@x@.
    typeSize :: Natural
-- | An existential wrapper around a value of some 'SymbolicData' type @t@
-- over @a@. A 'Typeable' dictionary for @t@ is stored with the value,
-- presumably so consumers can recover @t@ (e.g. with a 'Typeable'-based
-- cast) — confirm against call sites.
data SomeData a where
    -- | Wrap a value, hiding its concrete type.
    SomeData :: (Typeable t, SymbolicData a t) => t -> SomeData a
-- | The unit type carries no data: it flattens to no circuits and can only
-- be restored from an empty list.
instance Arithmetic a => SymbolicData a () where
    -- Matching on @()@ (not @_@) means the argument is evaluated.
    pieces () = []

    restore circuits = case circuits of
        [] -> ()
        _  -> error "restore (): wrong number of arguments"

    typeSize = 0
-- | A pair is laid out as the circuits of its first component followed by
-- the circuits of its second component.
instance (SymbolicData a x, SymbolicData a y) => SymbolicData a (x, y) where
    pieces (l, r) = pieces l ++ pieces r

    -- The input must contain exactly @typeSize \@a \@(x, y)@ circuits; the
    -- first @typeSize \@a \@x@ of them rebuild the left component.
    restore circuits
        | length circuits == typeSize @a @(x, y) =
            let (leftPart, rightPart) = splitAt (typeSize @a @x) circuits
            in (restore leftPart, restore rightPart)
        | otherwise = error "restore: wrong number of arguments"

    typeSize = typeSize @a @x + typeSize @a @y
-- | A triple is laid out as the circuits of its three components, in order.
instance (SymbolicData a x, SymbolicData a y, SymbolicData a z) => SymbolicData a (x, y, z) where
    pieces (p, q, r) = pieces p ++ (pieces q ++ pieces r)

    -- The input must contain exactly @typeSize \@a \@(x, y, z)@ circuits; it
    -- is cut into consecutive slices sized by each component's 'typeSize'.
    restore circuits
        | length circuits == typeSize @a @(x, y, z) =
            let (firstPart, afterFirst) = splitAt (typeSize @a @x) circuits
                (secondPart, thirdPart) = splitAt (typeSize @a @y) afterFirst
            in (restore firstPart, restore secondPart, restore thirdPart)
        | otherwise = error "restore: wrong number of arguments"

    typeSize = typeSize @a @x + typeSize @a @y + typeSize @a @z
-- | A length-@n@ vector is laid out as the circuits of its elements,
-- concatenated in order; each element occupies exactly @typeSize \@a \@x@
-- circuits.
instance (SymbolicData a x, KnownNat n) => SymbolicData a (Vector n x) where
    pieces (Vector xs) = concatMap pieces xs

    -- The input must contain exactly @typeSize \@a \@x * n@ circuits;
    -- otherwise this calls 'error'.
    restore rs
        | length rs /= typeSize @a @(Vector n x) = error "restore: wrong number of arguments"
        | otherwise = Vector (slices (value @n) rs)
        where
            -- Restore @k@ consecutive elements, each from the next
            -- @typeSize \@a \@x@ circuits. Counting down from @value \@n@
            -- (rather than enumerating indices up to @n - 1@) yields an empty
            -- vector when @n = 0@ instead of underflowing, and splitting the
            -- list as we go avoids re-dropping from the front per element.
            slices :: Natural -> [ArithmeticCircuit a] -> [x]
            slices 0 _  = []
            slices k cs =
                let (c, rest) = splitAt (typeSize @a @x) cs
                in restore @a @x c : slices (k -! 1) rest

    typeSize = typeSize @a @x * value @n
-- | Types that can be turned into arithmetic circuits: given a value of type
-- @x@ and a list of input circuits, 'arithmetize' produces output circuits.
-- Per the instances in this module, plain 'SymbolicData' values take no
-- inputs, while a function @x -> f@ consumes @typeSize \@a \@x@ inputs for
-- its argument and passes the remaining inputs on to @f@.
class Arithmetic a => Arithmetizable a x where
    -- | Build the output circuits of @x@ from the given input circuits.
    arithmetize :: x -> [ArithmeticCircuit a] -> [ArithmeticCircuit a]

    -- | The number of input circuits @x@ expects. The type is ambiguous,
    -- so it is used as @inputSize \@a \@x@.
    inputSize :: Natural

    -- | The number of output circuits @x@ is expected to produce. The type
    -- is ambiguous, so it is used as @outputSize \@a \@x@.
    outputSize :: Natural
-- | An existential wrapper around a value of some 'Arithmetizable' type @t@
-- over @a@. A 'Typeable' dictionary for @t@ is stored with the value,
-- presumably so consumers can recover @t@ — confirm against call sites.
data SomeArithmetizable a where
    -- | Wrap a value, hiding its concrete type.
    SomeArithmetizable :: (Typeable t, Arithmetizable a t) => t -> SomeArithmetizable a
-- | Fallback instance: any 'SymbolicData' value is a circuit with no inputs
-- whose outputs are its own 'pieces'. The function instance is more
-- specific and takes precedence for arrow types.
instance {-# OVERLAPPABLE #-} SymbolicData a x => Arithmetizable a x where
    arithmetize datum inputs = case inputs of
        [] -> pieces datum
        _  -> error "arithmetize: wrong number of inputs"

    inputSize = 0

    outputSize = typeSize @a @x
-- | A function is arithmetized by restoring its argument from the first
-- @typeSize \@a \@x@ input circuits, applying the function, and
-- arithmetizing the result with the remaining inputs.
instance (SymbolicData a x, Arithmetizable a f) => Arithmetizable a (x -> f) where
    arithmetize fun inputs = arithmetize (fun (restore argInputs)) remaining
      where
        (argInputs, remaining) = splitAt (typeSize @a @x) inputs

    inputSize = typeSize @a @x + inputSize @a @f

    outputSize = outputSize @a @f