{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Symbolic.Compiler.Arithmetizable (
        Arithmetic,
        Arithmetizable (..),
        SomeArithmetizable (..),
        SomeData (..),
        SymbolicData (..),
    ) where

import           Data.Typeable                                       (Typeable)
import           Numeric.Natural                                     (Natural)
import           Prelude                                             hiding (Num (..), drop, length, product, splitAt,
                                                                      sum, take, (!!), (^))

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Data.Vector                             hiding (concat)
import           ZkFold.Prelude                                      (drop, length, splitAt, take)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal (Arithmetic, ArithmeticCircuit)

-- | Data types that can be taken apart into, and rebuilt from, a list of
-- arithmetic circuits.
--
-- * @a@ is the finite field of the arithmetic circuit.
-- * @x@ is the data type being represented.
class Arithmetic a => SymbolicData a x where
    -- | How many finite field elements it takes to describe a value of @x@.
    -- Neither class parameter appears in this type, so it is used with
    -- type applications, e.g. @typeSize \@a \@x@.
    typeSize :: Natural

    -- | The circuits that a value of @x@ is made of.
    pieces :: x -> [ArithmeticCircuit a]

    -- | Rebuilds a value of @x@ from the outputs of its circuits.
    restore :: [ArithmeticCircuit a] -> x

-- | An existential wrapper around any value whose type is both 'Typeable'
-- and an instance of 'SymbolicData' over the field @a@. The concrete type
-- @t@ of the wrapped value is hidden.
data SomeData a where
    SomeData :: (Typeable t, SymbolicData a t) => t -> SomeData a

-- | The unit type carries no information, so it is described by zero
-- circuits and can only be restored from an empty list.
instance Arithmetic a => SymbolicData a () where
    typeSize = 0

    -- The pattern forces the argument, exactly as a @()@ match does.
    pieces () = []

    restore circuits = case circuits of
        [] -> ()
        _  -> error "restore (): wrong number of arguments"

-- | A pair is laid out as the circuits of its first component followed by
-- the circuits of its second component.
instance (SymbolicData a x, SymbolicData a y) => SymbolicData a (x, y) where
    typeSize = typeSize @a @x + typeSize @a @y

    pieces (l, r) = pieces l ++ pieces r

    restore circuits
        | length circuits == typeSize @a @(x, y) =
            -- The first @typeSize \@a \@x@ circuits belong to the left
            -- component; everything after them belongs to the right one.
            let (leftPart, rightPart) = splitAt (typeSize @a @x) circuits
            in (restore leftPart, restore rightPart)
        | otherwise = error "restore: wrong number of arguments"

-- | A triple is laid out as the circuits of its three components, in order.
instance (SymbolicData a x, SymbolicData a y, SymbolicData a z) => SymbolicData a (x, y, z) where
    typeSize = typeSize @a @x + typeSize @a @y + typeSize @a @z

    pieces (p, q, r) = pieces p ++ (pieces q ++ pieces r)

    restore circuits
        | length circuits /= typeSize @a @(x, y, z) = error "restore: wrong number of arguments"
        | otherwise =
            -- Peel off the first component's circuits, then split the
            -- remainder between the second and third components.
            let (firstPart, tailPart)   = splitAt (typeSize @a @x) circuits
                (secondPart, thirdPart) = splitAt (typeSize @a @y) tailPart
            in (restore firstPart, restore secondPart, restore thirdPart)

-- | A length-@n@ vector is laid out as the circuits of its elements,
-- concatenated in order. Element @i@ occupies the slice that starts at
-- offset @i * typeSize \@a \@x@.
instance (SymbolicData a x, KnownNat n) => SymbolicData a (Vector n x) where
    pieces (Vector xs) = concatMap pieces xs

    restore rs
        | length rs /= typeSize @a @(Vector n x) = error "restore: wrong number of arguments"
        | otherwise = Vector (restoreAt <$> take (value @n) [0 ..])
        where
            -- Number of circuits that describe a single element.
            elemSize = typeSize @a @x

            -- Restores the element at index @i@ from its slice of @rs@.
            -- The indices are the first @value \@n@ naturals, rather than
            -- @[0 .. value \@n -! 1]@, so that @n = 0@ yields an empty
            -- vector instead of evaluating @0 -! 1@.
            restoreAt i = restore @a @x (take elemSize (drop (i * elemSize) rs))

    typeSize = typeSize @a @x * value @n

-- | Types whose computations can be expressed as arithmetic circuits.
--
-- * @a@ is the finite field of the arithmetic circuit.
-- * @x@ is the arithmetizable type.
class Arithmetic a => Arithmetizable a x where
    -- | How many finite field elements it takes to describe an input of @x@.
    -- Used with type applications, e.g. @inputSize \@a \@x@.
    inputSize :: Natural

    -- | How many finite field elements it takes to describe the result of
    -- @x@. Used with type applications, e.g. @outputSize \@a \@x@.
    outputSize :: Natural

    -- | Takes circuits that compute the inputs and returns circuits that
    -- compute the result of @x@.
    arithmetize :: x -> [ArithmeticCircuit a] -> [ArithmeticCircuit a]

-- | An existential wrapper around any value whose type is both 'Typeable'
-- and an instance of 'Arithmetizable' over the field @a@. The concrete type
-- @t@ of the wrapped value is hidden.
data SomeArithmetizable a where
    SomeArithmetizable :: (Typeable t, Arithmetizable a t) => t -> SomeArithmetizable a

-- | Fallback instance: a symbolic datum is a computation that takes no
-- inputs and whose result is the datum's own circuits.
instance {-# OVERLAPPABLE #-} SymbolicData a x => Arithmetizable a x where
    inputSize = 0

    outputSize = typeSize @a @x

    arithmetize datum inputs = case inputs of
        [] -> pieces datum
        _  -> error "arithmetize: wrong number of inputs"

-- | A function is arithmetized one argument at a time. The first
-- @typeSize \@a \@x@ input circuits are restored into the argument, the
-- function is applied to it, and the result is arithmetized with the
-- remaining inputs.
instance (SymbolicData a x, Arithmetizable a f) => Arithmetizable a (x -> f) where
    arithmetize fn inputs = arithmetize (fn (restore argument)) rest
      where
        (argument, rest) = splitAt (typeSize @a @x) inputs

    inputSize = typeSize @a @x + inputSize @a @f

    outputSize = outputSize @a @f