{-# OPTIONS -fno-warn-orphans #-}
{-# LANGUAGE DeriveAnyClass, DeriveFoldable, DeriveFunctor, DeriveGeneric,
FlexibleInstances, ParallelListComp, RecordWildCards,
ScopedTypeVariables, TupleSections #-}
module QAP
( QapSet(..)
, QAP(..)
, updateAtWire
, lookupAtWire
, cnstInpQapSet
, sumQapSet
, sumQapSetCnstInp
, sumQapSetMidOut
, foldQapSet
, combineWithDefaults
, combineInputsWithDefaults
, combineNonInputsWithDefaults
, verifyAssignment
, verificationWitness
, verificationWitnessZk
, gateToQAP
, gateToGenQAP
, qapSetToMap
, initialQapSet
, generateAssignmentGate
, generateAssignment
, addMissingZeroes
, arithCircuitToGenQAP
, arithCircuitToQAP
, arithCircuitToQAPFFT
, createPolynomials
, createPolynomialsFFT
) where
import Protolude hiding (quot, quotRem)
import Data.Aeson (FromJSON, ToJSON)
import Data.Aeson.Types
import Data.Foldable (foldr1)
import Data.Map (Map, fromList, mapKeys)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Merge
import Data.Euclidean (Euclidean(..))
import Data.Field (Field)
import Data.Field.Galois (GaloisField, Prime, pow)
import Data.Poly
import qualified Data.Vector as V
import Text.PrettyPrint.Leijen.Text (Pretty(..), enclose, indent,
lbracket, rbracket, text, vcat,
(<+>))
import Circuit.Affine (affineCircuitToAffineMap)
import Circuit.Arithmetic (ArithCircuit(..), Gate(..), Wire(..),
evalArithCircuit, evalGate)
import qualified FFT
-- | One value per slot of a circuit: a constant slot plus one map per wire
-- kind.
--
-- The maps are keyed by wire index: 'lookupAtWire' and 'updateAtWire' use the
-- index carried by an 'InputWire', 'IntermediateWire' or 'OutputWire' as the
-- key into the matching map. The derived 'Foldable' instance visits the
-- constant first, then the input, intermediate and output maps (each in
-- ascending key order).
data QapSet f = QapSet
  { qapSetConstant :: f
    -- ^ Value for the constant (non-wire) slot.
  , qapSetInput :: Map Int f
    -- ^ Values for input wires, keyed by wire index.
  , qapSetIntermediate :: Map Int f
    -- ^ Values for intermediate wires, keyed by wire index.
  , qapSetOutput :: Map Int f
    -- ^ Values for output wires, keyed by wire index.
  } deriving (Show, Eq, Functor, Foldable, Generic, NFData, ToJSON, FromJSON)
-- | A quadratic arithmetic program.
--
-- Each slot of a 'QapSet' (the constant and every wire) carries one
-- polynomial per side: left inputs, right inputs and outputs.
-- 'verificationWitnessZk' weights these by an assignment, sums each side, and
-- accepts when @left * right - output@ is divisible by 'qapTarget'.
data QAP f = QAP
  { qapInputsLeft :: QapSet (VPoly f)
    -- ^ Left-input polynomial for every slot.
  , qapInputsRight :: QapSet (VPoly f)
    -- ^ Right-input polynomial for every slot.
  , qapOutputs :: QapSet (VPoly f)
    -- ^ Output polynomial for every slot.
  , qapTarget :: VPoly f
    -- ^ Target polynomial that must divide @left * right - output@.
  } deriving (Show, Eq, Generic, NFData, ToJSON, FromJSON)
-- | Serialise a polynomial as the JSON encoding of its coefficient vector
-- (as returned by 'unPoly').
--
-- Only @ToJSON f@ is required to encode the coefficient vector; no
-- @Generic f@ constraint is needed.
instance ToJSON f => ToJSON (VPoly f) where
  toJSON = toJSON . unPoly

-- | Parse a polynomial from the JSON encoding of its coefficient vector,
-- normalising it with 'toPoly'.
--
-- @Eq f@ and @Num f@ are what 'toPoly' requires; @Generic f@ is not needed.
instance (FromJSON f, Eq f, Num f) => FromJSON (VPoly f) where
  parseJSON v = toPoly <$> parseJSON v
-- Orphan JSON instances for prime-field elements. Both bodies are empty, so
-- they use the classes' default methods.
-- NOTE(review): aeson's defaults are Generic-based, so this presumably relies
-- on a Generic instance for 'Prime n' provided by galois-field — confirm.
instance ToJSON (Prime n)
instance FromJSON (Prime n)
-- | A QAP-shaped value whose entries are held in a container @p@ rather than
-- as polynomials.
--
-- Within this module @p@ is either @(,) k@ or @Map k@:
--
-- * @(,) k@: one root paired with a value, as produced by 'gateToGenQAP'.
-- * @Map k@: values at many roots, as produced by 'createMapGenQap'. These
--   are later interpolated into a 'QAP' by 'createPolynomials' or
--   'createPolynomialsFFT'.
data GenQAP p f = GenQAP
  { genQapInputsLeft :: QapSet (p f)
    -- ^ Left-input entries for every slot.
  , genQapInputsRight :: QapSet (p f)
    -- ^ Right-input entries for every slot.
  , genQapOutputs :: QapSet (p f)
    -- ^ Output entries for every slot.
  , genQapTarget :: p f
    -- ^ Target entry (the roots themselves, paired with 0).
  } deriving (Show, Eq, Generic, NFData, ToJSON, FromJSON)
-- | Transpose a list of 'QapSet's into a single 'QapSet' of lists.
--
-- The constant slot collects every set's constant, in list order. For each
-- wire map, the list stored under a wire index holds that wire's values from
-- every set that mentions it, in the same order as @qapSets@. Sets that do not
-- mention a wire contribute nothing to its list.
sequenceQapSet :: forall f. [QapSet f] -> QapSet [f]
sequenceQapSet qapSets = QapSet constants inputs mids outputs
  where
    constants = map qapSetConstant qapSets
    inputs = gather qapSetInput
    mids = gather qapSetIntermediate
    outputs = gather qapSetOutput

    -- Appending each new value to the end of the accumulated list (as
    -- @Map.unionsWith (<>)@ over singletons does) is quadratic in the number
    -- of sets mentioning a wire. Instead, 'Map.fromListWith' calls its
    -- combining function as @f new old@, so @(++)@ conses each value onto the
    -- front in O(1). A single 'reverse' per wire then restores list order.
    gather :: (QapSet f -> Map Int f) -> Map Int [f]
    gather field
      = fmap reverse
      . Map.fromListWith (++)
      $ [ (ix, [x]) | qs <- qapSets, (ix, x) <- Map.toList (field qs) ]
-- | A 'QapSet' holding only the given constant; every wire map is empty.
constantQapSet :: g -> QapSet g
constantQapSet g = QapSet g Map.empty Map.empty Map.empty

-- | A 'QapSet' holding the given constant and input-wire values, with empty
-- intermediate and output maps.
cnstInpQapSet :: g -> Map Int g -> QapSet g
cnstInpQapSet g inp = QapSet g inp Map.empty Map.empty
-- | Combine every value of a 'QapSet' with its 'Monoid' instance, in the order
-- constant, inputs, intermediates, outputs.
sumQapSet :: Monoid g => QapSet g -> g
sumQapSet (QapSet c i m o) = c <> fold i <> fold m <> fold o

-- | Combine only the constant and the input-wire values.
sumQapSetCnstInp :: Monoid g => QapSet g -> g
sumQapSetCnstInp QapSet { qapSetConstant = c, qapSetInput = i }
  = c <> fold i

-- | Combine only the intermediate- and output-wire values.
sumQapSetMidOut :: Monoid g => QapSet g -> g
sumQapSetMidOut QapSet { qapSetIntermediate = m, qapSetOutput = o }
  = fold m <> fold o
-- | Render a rational through its 'Show' instance.
instance Pretty (Ratio Integer) where
  pretty r = text (show r)
-- | Render the constant on one line, followed by the inputs, outputs and
-- intermediates. Each group has a heading and one indented @[index] value@
-- line per wire.
instance Pretty f => Pretty (QapSet f) where
  pretty qs = vcat $ concat
    [ [text "constant:" <+> pretty (qapSetConstant qs)]
    , section "inputs:" (qapSetInput qs)
    , section "outputs:" (qapSetOutput qs)
    , section "intermediates:" (qapSetIntermediate qs)
    ]
    where
      section heading entries = [text heading, indent 2 (ppEntries entries)]
      ppEntries entries = vcat (map ppEntry (Map.toList entries))
      ppEntry (ix, x) = enclose lbracket rbracket (pretty ix) <+> pretty x
-- | Merge two maps key by key with @f@.
--
-- A key present in only one map is combined with the other side's default
-- (@defaultA@ or @defaultB@), so the result has exactly the union of both
-- maps' keys. Shared by the @combine*WithDefaults@ family below.
mergeWithDefaults
  :: Ord k
  => (a -> b -> c)
  -> a
  -> b
  -> Map k a
  -> Map k b
  -> Map k c
mergeWithDefaults f defaultA defaultB
  = Merge.merge onlyInA onlyInB inBoth
  where
    onlyInA = Merge.mapMissing $ \_ a -> f a defaultB
    onlyInB = Merge.mapMissing $ \_ b -> f defaultA b
    inBoth = Merge.zipWithMatched $ \_ a b -> f a b

-- | Combine two 'QapSet's slot by slot with @f@.
--
-- The constants are always combined. For each wire map, a wire present in
-- only one set is combined with the other side's default.
combineWithDefaults
  :: (a -> b -> c)
  -> a
  -> b
  -> QapSet a
  -> QapSet b
  -> QapSet c
combineWithDefaults f defaultA defaultB (QapSet cA inpA midA outpA) (QapSet cB inpB midB outpB)
  = QapSet
    { qapSetConstant = f cA cB
    , qapSetInput = combineMaps inpA inpB
    , qapSetIntermediate = combineMaps midA midB
    , qapSetOutput = combineMaps outpA outpB
    }
  where
    combineMaps = mergeWithDefaults f defaultA defaultB

-- | Like 'combineWithDefaults', but only the constant and the input wires are
-- combined. The intermediate and output maps of the result are empty.
combineInputsWithDefaults
  :: (a -> b -> c)
  -> a
  -> b
  -> QapSet a
  -> QapSet b
  -> QapSet c
combineInputsWithDefaults f defaultA defaultB (QapSet cA inpA _ _) (QapSet cB inpB _ _)
  = QapSet
    { qapSetConstant = f cA cB
    , qapSetInput = mergeWithDefaults f defaultA defaultB inpA inpB
    , qapSetIntermediate = mempty
    , qapSetOutput = mempty
    }

-- | Like 'combineWithDefaults', but only the intermediate and output wires are
-- combined.
--
-- The constant of the result is @defaultC@ and its input map is empty.
combineNonInputsWithDefaults
  :: (a -> b -> c)
  -> a
  -> b
  -> c
  -> QapSet a
  -> QapSet b
  -> QapSet c
combineNonInputsWithDefaults f defaultA defaultB defaultC (QapSet _ _ midA outpA) (QapSet _ _ midB outpB)
  = QapSet
    { qapSetConstant = defaultC
    , qapSetInput = mempty
    , qapSetIntermediate = combineMaps midA midB
    , qapSetOutput = combineMaps outpA outpB
    }
  where
    combineMaps = mergeWithDefaults f defaultA defaultB
-- | Right-fold every value of a 'QapSet' with a binary operator.
--
-- Values are visited in the order constant, inputs, intermediates, outputs
-- (each map in ascending key order). A 'QapSet' always holds its constant, so
-- the fold is never over an empty structure.
foldQapSet
  :: (a -> a -> a)
  -> QapSet a
  -> a
foldQapSet f = foldr1 f . toList
-- | Merge per-root 'GenQAP's (as produced by 'gateToGenQAP') into one
-- 'GenQAP' whose entries map each root to its value.
--
-- For every slot, the @(root, value)@ pairs from all inputs that mention the
-- slot are gathered into a 'Map'. If a root repeats for the same slot,
-- 'Map.fromList' keeps the last pair.
createMapGenQap :: Ord k => [GenQAP ((,) k) k] -> GenQAP (Map k) k
createMapGenQap genQaps
  = GenQAP
    { genQapInputsLeft = byRoot genQapInputsLeft
    , genQapInputsRight = byRoot genQapInputsRight
    , genQapOutputs = byRoot genQapOutputs
    , genQapTarget = Map.fromList (map genQapTarget genQaps)
    }
  where
    byRoot field = Map.fromList <$> sequenceQapSet (map field genQaps)
-- | Render a 'QAP'. Each side is printed with its 'Show' instance, indented
-- under a heading, and the target goes on the final line.
instance (Eq f, Num f, Pretty f, Show f) => Pretty (QAP f) where
  pretty qap = vcat $ concat
    [ [text "QAP:"]
    , shownSection "inputs left:" (qapInputsLeft qap)
    , shownSection "inputs right:" (qapInputsRight qap)
    , shownSection "outputs:" (qapOutputs qap)
    , [text "target: " <> text (show (qapTarget qap))]
    ]
    where
      -- A heading line followed by the value's 'Show' output, indented by 2.
      shownSection heading x = [text heading, indent 2 (text (show x))]
-- | Render a 'GenQAP' with each side pretty-printed under an indented
-- heading, and the target on the final line.
instance (Pretty f, Pretty (p f)) => Pretty (GenQAP p f) where
  pretty genQap = vcat $ concat
    [ [text "QAP:"]
    , labelled "inputs left:" (genQapInputsLeft genQap)
    , labelled "inputs right:" (genQapInputsRight genQap)
    , labelled "outputs:" (genQapOutputs genQap)
    , [text "target: " <> pretty (genQapTarget genQap)]
    ]
    where
      -- A heading line followed by the pretty-printed value, indented by 2.
      labelled heading x = [text heading, indent 2 (pretty x)]
-- | Map over every value stored in a 'GenQAP', through both the 'QapSet'
-- layer and the container @p@.
instance Functor p => Functor (GenQAP p) where
  fmap f (GenQAP left right out target) =
    GenQAP (overSet left) (overSet right) (overSet out) (fmap f target)
    where
      overSet set = fmap (fmap f) set
-- | 'True' exactly when 'verificationWitness' finds a witness for the
-- assignment.
verifyAssignment
  :: (Eq f, Field f, Num f)
  => QAP f
  -> QapSet f
  -> Bool
verifyAssignment qap = isJust . verificationWitness qap

-- | 'verificationWitnessZk' with every blinding factor set to 0.
verificationWitness
  :: forall f . (Eq f, Field f, Num f)
  => QAP f
  -> QapSet f
  -> Maybe (VPoly f)
verificationWitness qap assignment
  = verificationWitnessZk noBlinding noBlinding noBlinding qap assignment
  where
    noBlinding :: f
    noBlinding = 0
-- | Check an assignment against a QAP, with blinding factors.
--
-- Each side (left inputs, right inputs, outputs) is the sum of that side's
-- polynomials, each scaled by the matching assignment value, plus
-- @delta * qapTarget@ for that side's delta. A slot missing on either side
-- counts as 0. The assignment is accepted when @left * right - output@ is
-- divisible by 'qapTarget'. The quotient is then returned as the witness;
-- otherwise the result is 'Nothing'.
verificationWitnessZk
  :: (Eq f, Field f, Num f)
  => f
  -> f
  -> f
  -> QAP f
  -> QapSet f
  -> Maybe (VPoly f)
verificationWitnessZk delta1 delta2 delta3 qap assignment
  | remainder == 0 = Just quotient
  | otherwise = Nothing
  where
    target = qapTarget qap
    weigh poly value = monomial 0 value * poly
    side delta polys
      = monomial 0 delta * target
      + foldQapSet (+) (combineWithDefaults weigh 0 0 polys assignment)
    combined
      = side delta1 (qapInputsLeft qap) * side delta2 (qapInputsRight qap)
      - side delta3 (qapOutputs qap)
    (quotient, remainder) = quotRem combined target
-- | Look up a wire's value in the map matching its kind, keyed by the wire's
-- index.
lookupAtWire :: Wire -> QapSet a -> Maybe a
lookupAtWire wire qs = case wire of
  InputWire ix -> Map.lookup ix (qapSetInput qs)
  IntermediateWire ix -> Map.lookup ix (qapSetIntermediate qs)
  OutputWire ix -> Map.lookup ix (qapSetOutput qs)

-- | Set a wire's value in the map matching its kind, replacing any existing
-- value at that index.
updateAtWire :: Wire -> a -> QapSet a -> QapSet a
updateAtWire wire a qs = case wire of
  InputWire ix -> qs { qapSetInput = Map.insert ix a (qapSetInput qs) }
  IntermediateWire ix -> qs { qapSetIntermediate = Map.insert ix a (qapSetIntermediate qs) }
  OutputWire ix -> qs { qapSetOutput = Map.insert ix a (qapSetOutput qs) }

-- | Apply 'updateAtWire' for each pair, left to right. A later pair for the
-- same wire overwrites an earlier one.
updateAtWires :: [(Wire, a)] -> QapSet a -> QapSet a
updateAtWires wireVals vars = foldl' (flip (uncurry updateAtWire)) vars wireVals
-- | Build the QAP of a single gate.
--
-- The gate's per-root entries from 'gateToGenQAP' are merged, padded with
-- zeros at every root, and interpolated with 'createPolynomialsFFT'. @roots@
-- must have the length 'gateToGenQAP' expects for the gate.
gateToQAP
  :: GaloisField k
  => (Int -> k)
  -> [k]
  -> Gate Wire k
  -> QAP k
gateToQAP primRoots roots gate =
  let perRoot = gateToGenQAP roots gate
      merged = addMissingZeroes roots (createMapGenQap perRoot)
  in createPolynomialsFFT primRoots merged
-- | Translate one gate into one 'GenQAP' per root assigned to it.
--
-- Every coefficient is paired with its root, so the per-root results can later
-- be merged by 'createMapGenQap' and interpolated. The number of roots needed
-- depends on the gate:
--
-- * 'Mul': exactly one root.
-- * 'Equal': exactly two roots.
-- * 'Split': one root, plus one more per output wire.
--
-- Any other combination calls 'panic'.
--
-- NOTE(review): coefficients are written with 'updateAtWires' / 'updateAtWire',
-- where a later entry for the same wire overwrites an earlier one. If two of a
-- gate's wires coincide, only the last coefficient survives — confirm this
-- cannot happen for well-formed circuits.
gateToGenQAP
  :: (GaloisField k)
  => [k]
  -> Gate Wire k
  -> [GenQAP ((,) k) k]
-- Mul: one constraint @left * right = output@.
-- @affineCircuitToAffineMap@ splits each operand into a constant term (stored
-- in the constant slot) and a map from wires to coefficients (stored per
-- wire). The gate's output wire gets coefficient 1 on the output side.
gateToGenQAP [root] (Mul l r wire)
  = pure
  . addOutputVals
  . addInputVals
  $ GenQAP
    { genQapInputsLeft = constantQapSet (root, leftInputConst)
    , genQapInputsRight = constantQapSet (root, rightInputConst)
    , genQapOutputs = constantQapSet (root, 0)
    , genQapTarget = (root, 0)
    }
  where
    (leftInputConst, leftInputVector) = affineCircuitToAffineMap l
    (rightInputConst, rightInputVector) = affineCircuitToAffineMap r
    -- Store each operand's per-wire coefficient, paired with the root.
    addInputVals (GenQAP left right out t)
      = GenQAP (Map.foldrWithKey updateAtWire left $ fmap (root,) leftInputVector)
               (Map.foldrWithKey updateAtWire right $ fmap (root,) rightInputVector)
               out
               t
    -- The gate's result wire appears with coefficient 1 on the output side.
    addOutputVals (GenQAP left right out t)
      = GenQAP left
               right
               (updateAtWire wire (root, 1) out)
               t
-- Equal: two constraints.
--   root0: i * m = outputWire
--   root1: (1 - outputWire) * i = 0
-- So outputWire must be 1 whenever i is nonzero. The first constraint forces
-- outputWire to 0 when i is 0. m is presumably meant to be i's inverse when
-- i is nonzero — confirm with the gate evaluator.
gateToGenQAP [root0,root1] (Equal i m outputWire)
  = [qap0, qap1]
  where
    qap0 = GenQAP
      { genQapInputsLeft
          = updateAtWires [ (i, (root0, 1))
                          , (m, (root0, 0))
                          , (outputWire, (root0, 0))
                          ]
          $ constantQapSet (root0, 0)
      , genQapInputsRight
          = updateAtWires [ (i, (root0, 0))
                          , (m, (root0, 1))
                          , (outputWire, (root0, 0))
                          ]
          $ constantQapSet (root0, 0)
      , genQapOutputs
          = updateAtWires [ (i, (root0, 0))
                          , (m, (root0, 0))
                          , (outputWire, (root0, 1))
                          ]
          $ constantQapSet (root0, 0)
      , genQapTarget
          = (root0, 0)
      }
    qap1 = GenQAP
      { genQapInputsLeft
          = updateAtWires [ (i, (root1, 0))
                          , (m, (root1, 0))
                          , (outputWire, (root1, -1))
                          ]
          $ constantQapSet (root1, 1)
      , genQapInputsRight
          = updateAtWires [ (i, (root1, 1))
                          , (m, (root1, 0))
                          , (outputWire, (root1, 0))
                          ]
          $ constantQapSet (root1, 0)
      , genQapOutputs
          = updateAtWires [ (i, (root1, 0))
                          , (m, (root1, 0))
                          , (outputWire, (root1, 0))
                          ]
          $ constantQapSet (root1, 0)
      , genQapTarget
          = (root1, 0)
      }
-- Split: bit decomposition of @inp@ into the wires @outputs@.
-- First root:  (sum_j 2^j * outputs_j) * 1 = inp, with j counting from 0 in
--              list order.
-- Each further root (one per output b): b * (1 - b) = 0, so b is 0 or 1.
gateToGenQAP (root:roots) (Split inp outputs)
  = if length roots /= length outputs
    then panic "gateToGenQAP: wrong number of roots supplied"
    else qap0:zipWith qaps roots outputs
  where
    qap0 = GenQAP
      { genQapInputsLeft
          = updateAtWires ((inp, (root, 0)):zipWith (\output i -> (output, (root, 2 `pow` i))) outputs [0 :: Integer ..])
          $ constantQapSet (root, 0)
      , genQapInputsRight
          = updateAtWires [(inp, (root, 0))]
          $ constantQapSet (root, 1)
      , genQapOutputs
          = updateAtWires [(inp, (root, 1))]
          $ constantQapSet (root, 0)
      , genQapTarget
          = (root, 0)
      }
    -- Booleanity constraint for a single output wire at root r.
    qaps r outp = GenQAP
      { genQapInputsLeft
          = updateAtWire outp (r, 1)
          $ constantQapSet (r, 0)
      , genQapInputsRight
          = updateAtWire outp (r, -1)
          $ constantQapSet (r, 1)
      , genQapOutputs
          = updateAtWire outp (r, 0)
          $ constantQapSet (r, 0)
      , genQapTarget
          = (r, 0)
      }
gateToGenQAP _ _ = panic "gateToGenQAP: wrong number of roots supplied"
-- | Turn the per-root values of a 'GenQAP' into polynomials by Lagrange
-- interpolation.
--
-- Every slot's @root -> value@ map becomes the polynomial of degree below the
-- map's size that takes those values at those roots. The target polynomial is
-- the product of @(X - root)@ over every root in @targetRoots@.
createPolynomials :: forall k. (GaloisField k) => GenQAP (Map k) k -> QAP k
createPolynomials (GenQAP inpLeft inpRight outp targetRoots)
  = QAP
    { qapInputsLeft = fmap (lagrangeInterpolate . Map.toList) inpLeft
    , qapInputsRight = fmap (lagrangeInterpolate . Map.toList) inpRight
    , qapOutputs = fmap (lagrangeInterpolate . Map.toList) outp
    , qapTarget = foldl' (*) (monomial 0 1) . map linearFactor . Map.keys $ targetRoots
    }
  where
    -- The monic linear polynomial X - xi.
    linearFactor :: k -> VPoly k
    linearFactor xi = toPoly . V.fromList $ [-xi, 1]

    -- Lagrange interpolation through the points @xys@:
    --
    --   sum_i  y_i * (V(X) / (X - x_i)) / V'(x_i)
    --
    -- Here V is the product of all (X - x_j). V'(x_i) equals the product of
    -- (x_i - x_j) over j /= i, i.e. the usual Lagrange denominator.
    --
    -- A point with y_i == 0 contributes the zero polynomial, so it is skipped.
    -- This saves a polynomial division per such point, and 'addMissingZeroes'
    -- pads every map with a zero at each root that has no entry.
    lagrangeInterpolate :: [(k, k)] -> VPoly k
    lagrangeInterpolate xys = sum
      [ scale 0 (y / eval vanishingDeriv x) (vanishing `quot` linearFactor x)
      | (x, y) <- xys
      , y /= 0
      ]
      where
        vanishing :: VPoly k
        vanishing = foldl' (\acc (xi, _) -> acc * linearFactor xi) 1 xys
        vanishingDeriv :: VPoly k
        vanishingDeriv = deriv vanishing
-- | Turn the per-root values of a 'GenQAP' into polynomials using the FFT
-- helpers.
--
-- Each slot's values are passed to 'FFT.interpolate' in ascending key (root)
-- order. The target is 'FFT.fftTargetPoly' for the number of target roots.
-- NOTE(review): presumably FFT.interpolate expects values in the order of its
-- evaluation points (powers of the primitive root). Confirm that ascending
-- field order of the roots matches that.
createPolynomialsFFT
  :: GaloisField k
  => (Int -> k)
  -> GenQAP (Map k) k
  -> QAP k
createPolynomialsFFT primRoots (GenQAP inpLeft inpRight outp targetRoots)
  = QAP
    { qapInputsLeft = interp <$> inpLeft
    , qapInputsRight = interp <$> inpRight
    , qapOutputs = interp <$> outp
    , qapTarget = FFT.fftTargetPoly primRoots (Map.size targetRoots)
    }
  where
    interp values = FFT.interpolate primRoots (Map.elems values)
-- | Translate a whole circuit into a 'GenQAP' over all roots.
--
-- @rootsPerGate@ is zipped with the gates, so its i-th entry supplies the
-- roots for the i-th gate. Every map is then padded with zeros at every root.
arithCircuitToGenQAP
  :: GaloisField k
  => [[k]]
  -> ArithCircuit k
  -> GenQAP (Map k) k
arithCircuitToGenQAP rootsPerGate (ArithCircuit gates)
  = addMissingZeroes allRoots (createMapGenQap perRootQaps)
  where
    allRoots = concat rootsPerGate
    perRootQaps = concat (zipWith gateToGenQAP rootsPerGate gates)

-- | Translate a circuit into a 'QAP' using Lagrange interpolation
-- ('createPolynomials').
arithCircuitToQAP
  :: GaloisField k
  => [[k]]
  -> ArithCircuit k
  -> QAP k
arithCircuitToQAP roots = createPolynomials . arithCircuitToGenQAP roots

-- | Translate a circuit into a 'QAP' using FFT-based interpolation
-- ('createPolynomialsFFT').
arithCircuitToQAPFFT
  :: GaloisField k
  => (Int -> k)
  -> [[k]]
  -> ArithCircuit k
  -> QAP k
arithCircuitToQAPFFT primRoots roots
  = createPolynomialsFFT primRoots . arithCircuitToGenQAP roots
-- | Ensure every map in a 'GenQAP' has an entry for each of @allRoots@.
--
-- Roots that are absent get the value 0; existing entries are kept.
addMissingZeroes
  :: forall f . (Ord f, Num f)
  => [f] -> GenQAP (Map f f) f -> GenQAP (Map f f) f
addMissingZeroes allRoots (GenQAP inpLeft inpRight outp t)
  = GenQAP
    { genQapInputsLeft = padSet inpLeft
    , genQapInputsRight = padSet inpRight
    , genQapOutputs = padSet outp
    , genQapTarget = pad t
    }
  where
    -- 'Map.union' is left-biased, so existing entries win over the zeros.
    pad :: Map f f -> Map f f
    pad m = Map.union m zeroAtEveryRoot
    padSet :: QapSet (Map f f) -> QapSet (Map f f)
    padSet = fmap pad
    zeroAtEveryRoot :: Map f f
    zeroAtEveryRoot = Map.fromList [ (r, 0) | r <- allRoots ]
-- | Evaluate a single gate on the given input values, starting from
-- 'initialQapSet', and return the resulting wire assignment.
generateAssignmentGate
  :: (Bits f, Fractional f)
  => Gate Wire f
  -> Map Int f
  -> QapSet f
generateAssignmentGate program inps
  = evalGate lookupAtWire updateAtWire start program
  where
    start = initialQapSet inps

-- | The assignment before any gate runs. The constant slot is 1, the inputs
-- are the given values, and no intermediate or output wire has a value yet.
initialQapSet
  :: Num f
  => Map Int f
  -> QapSet f
initialQapSet inputs = QapSet
  { qapSetConstant = 1
  , qapSetInput = inputs
  , qapSetIntermediate = Map.empty
  , qapSetOutput = Map.empty
  }

-- | Evaluate a whole circuit on the given input values, starting from
-- 'initialQapSet', and return the resulting wire assignment.
generateAssignment
  :: forall f . (Bits f, Fractional f)
  => ArithCircuit f
  -> Map Int f
  -> QapSet f
generateAssignment circuit
  = evalArithCircuit lookupAtWire updateAtWire circuit . initialQapSet
-- | Flatten a 'QapSet' into a single map, laid out in this order:
--
-- * the constant at key 0;
-- * the inputs, shifted by 1;
-- * the intermediates, after the largest input key;
-- * the outputs, after the largest intermediate key.
--
-- Each group's offset is one past the previous group's largest shifted key
-- (or unchanged if that group is empty), so the groups never overlap.
qapSetToMap :: QapSet g -> Map Int g
qapSetToMap QapSet{..}
  = Map.unions
      [ fromList [(0, qapSetConstant)]
      , mapKeys (+ inputOffset) qapSetInput
      , mapKeys (+ intermediateOffset) qapSetIntermediate
      , mapKeys (+ outputOffset) qapSetOutput
      ]
  where
    inputOffset :: Int
    inputOffset = 1
    intermediateOffset :: Int
    intermediateOffset = inputOffset + keySpan qapSetInput
    outputOffset :: Int
    outputOffset = intermediateOffset + keySpan qapSetIntermediate
    -- One past the largest key, or 0 for an empty map.
    keySpan :: Map Int a -> Int
    keySpan m = maybe 0 ((+ 1) . fst) (Map.lookupMax m)