{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeApplications    #-}

module ZkFold.Symbolic.Compiler (
    module ZkFold.Symbolic.Compiler.Arithmetizable,
    module ZkFold.Symbolic.Compiler.ArithmeticCircuit,
    compile,
    compileIO
) where

import           Data.Aeson                                                (ToJSON)
import           Data.Foldable                                             (fold)
import           Prelude                                                   (FilePath, IO, Show (..), map, putStrLn, ($),
                                                                            (++))

import           ZkFold.Prelude                                            (replicateA, writeFileJSON)
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit
import           ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
import           ZkFold.Symbolic.Compiler.Arithmetizable

{-
    ZkFold Symbolic compiler module dependency order:
    1. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Internal
    2. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Map
    3. ZkFold.Symbolic.Compiler.ArithmeticCircuit.MonadBlueprint
    4. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Combinators
    5. ZkFold.Symbolic.Compiler.Arithmetizable
    6. ZkFold.Symbolic.Compiler.ArithmeticCircuit.Instance
    7. ZkFold.Symbolic.Compiler.ArithmeticCircuit
    8. ZkFold.Symbolic.Compiler
-}

-- | Arithmetizes @f@ by applying it to as many circuit inputs as it expects.
--
-- The number of inputs is taken from @'inputSize' \@a \@f@. That many
-- 'input' variables are produced inside a blueprint, turned into circuits
-- with 'circuits', and then passed to @'arithmetize' f@, which yields the
-- output circuits.
solder :: forall a f . Arithmetizable a f => f -> [ArithmeticCircuit a]
solder f = arithmetize f (circuits (replicateA (inputSize @a @f) input))

-- | Compiles function @f@ into arithmetic circuits and packs them into @y@.
--
-- Each output circuit of 'solder' is passed through 'optimize'. The
-- resulting list is then reassembled into a value of the symbolic type @y@
-- with @'restore' \@a@.
compile :: forall a f y . (Arithmetizable a f, SymbolicData a y) => f -> y
compile f = restore @a (map optimize (solder f))

-- | Compiles function @f@ into a single arithmetic circuit and writes it to a file.
--
-- All output circuits of 'solder' are combined with 'fold' and then
-- optimized. The constraint count ('acSizeN') and variable count
-- ('acSizeM') are printed, and the circuit is written to @scriptFile@
-- via 'writeFileJSON'.
compileIO :: forall a f . (ToJSON a, Arithmetizable a f) => FilePath -> f -> IO ()
compileIO scriptFile f = do
    -- 'ac' is a lazy binding, so the circuit is only built when its size is
    -- first demanded below. That is why the banner appears before the work.
    let ac :: ArithmeticCircuit a
        ac = optimize (fold (solder f))

    putStrLn "\nCompiling the script...\n"

    putStrLn $ "Number of constraints: " ++ show (acSizeN ac)
    putStrLn $ "Number of variables: " ++ show (acSizeM ac)
    writeFileJSON scriptFile ac
    putStrLn $ "Script saved: " ++ scriptFile