{-# OPTIONS -fno-warn-orphans #-}

{-# LANGUAGE DeriveAnyClass, DeriveFoldable, DeriveFunctor, DeriveGeneric,
             FlexibleInstances, ParallelListComp, RecordWildCards,
             ScopedTypeVariables, TupleSections #-}

-- | Definitions of quadratic arithmetic programs, along with their
-- assignment verification functions and the translations from single
-- multiplication- or equality-gates into QAPs and arithmetic circuits
-- into QAPs.
module QAP
  ( QapSet(..)
  , QAP(..)
  , updateAtWire
  , lookupAtWire
  , cnstInpQapSet
  , sumQapSet
  , sumQapSetCnstInp
  , sumQapSetMidOut
  , foldQapSet
  , combineWithDefaults
  , combineInputsWithDefaults
  , combineNonInputsWithDefaults
  , verifyAssignment
  , verificationWitness
  , verificationWitnessZk
  , gateToQAP
  , gateToGenQAP
  , qapSetToMap
  , initialQapSet
  , generateAssignmentGate
  , generateAssignment
  , addMissingZeroes
  , arithCircuitToGenQAP
  , arithCircuitToQAP
  , arithCircuitToQAPFFT
  , createPolynomials
  , createPolynomialsFFT
  ) where

import Protolude hiding (quot, quotRem)

import           Data.Aeson          (FromJSON, ToJSON)
import           Data.Aeson.Types
import           Data.Foldable       (foldr1)
import           Data.Map            (Map, fromList, mapKeys)
import qualified Data.Map            as Map
import qualified Data.Map.Merge.Lazy as Merge

import           Data.Euclidean               (Euclidean(..))
import           Data.Field                   (Field)
import           Data.Field.Galois            (GaloisField, Prime, pow)
import           Data.Poly
import qualified Data.Vector                  as V
import           Text.PrettyPrint.Leijen.Text (Pretty(..), enclose, indent,
                                               lbracket, rbracket, text, vcat,
                                               (<+>))

import           Circuit.Affine     (affineCircuitToAffineMap)
import           Circuit.Arithmetic (ArithCircuit(..), Gate(..), Wire(..),
                                     evalArithCircuit, evalGate)
import qualified FFT

-- | The sets of polynomials/constants as they occur in QAPs, grouped
-- into their constant, input, output and intermediate parts.
data QapSet f = QapSet
  { qapSetConstant :: f
  , qapSetInput :: Map Int f
  , qapSetIntermediate :: Map Int f
  , qapSetOutput :: Map Int f
  } deriving (Show, Eq, Functor, Foldable, Generic, NFData, ToJSON, FromJSON)

-- | Quadratic arithmetic program
data QAP f = QAP
  { qapInputsLeft :: QapSet (VPoly f)
  , qapInputsRight :: QapSet (VPoly f)
  , qapOutputs :: QapSet (VPoly f)
  , qapTarget :: VPoly f
  } deriving (Show, Eq, Generic, NFData, ToJSON, FromJSON)

-- Orphan instances for VPoly
instance (ToJSON f, Generic f) => ToJSON (VPoly f) where
  toJSON = toJSON . unPoly
instance (FromJSON f, Generic f, Eq f, Num f) => FromJSON (VPoly f) where
  parseJSON v = toPoly <$> parseJSON v

instance ToJSON (Prime n)
instance FromJSON (Prime n)

-- | Generalised quadratic arithmetic program: instead of @Poly@, allow
-- any functor.
data GenQAP p f = GenQAP
  { genQapInputsLeft :: QapSet (p f)
  , genQapInputsRight :: QapSet (p f)
  , genQapOutputs :: QapSet (p f)
  , genQapTarget :: p f
  } deriving (Show, Eq, Generic, NFData, ToJSON, FromJSON)

-- Note that we could get "sequence" from the Traversable instance of
-- lists if we had an Applicative/Monad instance of QapSet. There do
-- not seem to be sensible instances of those classes for QapSet.
sequenceQapSet :: [QapSet f] -> QapSet [f]
sequenceQapSet qapSets = QapSet constants inputs mids outputs
  where
    constants = map qapSetConstant qapSets
    inputs = Map.unionsWith (<>) . fmap (fmap pure) $ map qapSetInput qapSets
    mids = Map.unionsWith (<>) . fmap (fmap pure) $ map qapSetIntermediate qapSets
    outputs = Map.unionsWith (<>) . fmap (fmap pure) $ map qapSetOutput qapSets

-- | Create QapSet with only a constant value and empty maps for the
-- rest.
constantQapSet :: g -> QapSet g
constantQapSet g = QapSet
  { qapSetConstant = g
  , qapSetInput = Map.empty
  , qapSetIntermediate = Map.empty
  , qapSetOutput = Map.empty
  }

cnstInpQapSet :: g -> Map Int g -> QapSet g
cnstInpQapSet g inp = QapSet
  { qapSetConstant = g
  , qapSetInput = inp
  , qapSetIntermediate = Map.empty
  , qapSetOutput = Map.empty
  }
-- | Sum all the values contained in a QapSet.
sumQapSet :: Monoid g => QapSet g -> g
sumQapSet = fold

-- | Sum only over constant and input values
sumQapSetCnstInp :: Monoid g => QapSet g -> g
sumQapSetCnstInp (QapSet cnst inp _ _)
  = cnst <> fold inp

-- | Sum only over intermediate and output values
sumQapSetMidOut :: Monoid g => QapSet g -> g
sumQapSetMidOut (QapSet _ _ mid out)
  = fold mid <> fold out

instance Pretty (Ratio Integer) where
  pretty = text . show

instance Pretty f => Pretty (QapSet f) where
  pretty (QapSet constant inps mids outps)
    = vcat
      [ text "constant:" <+> pretty constant
      , text "inputs:"
      , indent 2 $ ppMap inps
      , text "outputs:"
      , indent 2 $ ppMap outps
      , text "intermediates:"
      , indent 2 $ ppMap mids
      ]
    where
       ppMap
        = vcat
        . map (\(ix, x) -> enclose lbracket rbracket (pretty ix) <+> pretty x)
        . Map.toList

combineWithDefaults
  :: (a -> b -> c) -- ^ function to combine the values with
  -> a -- ^ default left value
  -> b -- ^ default right value
  -> QapSet a -- ^ left QapSet
  -> QapSet b -- ^ right QapSet
  -> QapSet c
combineWithDefaults f defaultA defaultB (QapSet cA inpA midA outpA) (QapSet cB inpB midB outpB)
  = QapSet
    { qapSetConstant = f cA cB
    , qapSetInput = combineMaps inpA inpB
    , qapSetIntermediate = combineMaps midA midB
    , qapSetOutput = combineMaps outpA outpB
    }
  where
    combineMaps = Merge.merge missingRight missingLeft matching
    missingLeft = Merge.mapMissing $ const $ f defaultA
    missingRight = Merge.mapMissing $ const $ flip f defaultB
    matching = Merge.zipWithMatched $ const f

combineInputsWithDefaults
  :: (a -> b -> c) -- ^ function to combine the values with
  -> a -- ^ default left value
  -> b -- ^ default right value
  -> QapSet a -- ^ left QapSet
  -> QapSet b -- ^ right QapSet
  -> QapSet c
combineInputsWithDefaults f defaultA defaultB (QapSet cA inpA _ _) (QapSet cB inpB _ _)
  = QapSet
    { qapSetConstant = f cA cB
    , qapSetInput = combineMaps inpA inpB
    , qapSetIntermediate = mempty
    , qapSetOutput = mempty
    }
  where
    combineMaps = Merge.merge missingRight missingLeft matching
    missingLeft = Merge.mapMissing $ const $ f defaultA
    missingRight = Merge.mapMissing $ const $ flip f defaultB
    matching = Merge.zipWithMatched $ const f

combineNonInputsWithDefaults
  :: (a -> b -> c) -- ^ function to combine the values with
  -> a -- ^ default left value
  -> b -- ^ default right value
  -> c -- ^ default constant
  -> QapSet a -- ^ left QapSet
  -> QapSet b -- ^ right QapSet
  -> QapSet c
combineNonInputsWithDefaults f defaultA defaultB defaultC (QapSet _ _ midA outpA) (QapSet _ _ midB outpB)
  = QapSet
    { qapSetConstant = defaultC
    , qapSetInput = mempty
    , qapSetIntermediate = combineMaps midA midB
    , qapSetOutput = combineMaps outpA outpB
    }
  where
    combineMaps = Merge.merge missingRight missingLeft matching
    missingLeft = Merge.mapMissing $ const $ f defaultA
    missingRight = Merge.mapMissing $ const $ flip f defaultB
    matching = Merge.zipWithMatched $ const f

-- | Fold over a QapSet with an operation that is assumed to be
-- commutative.
foldQapSet
  :: (a -> a -> a) -- ^ *commutative* binary operation
  -> QapSet a -- ^ QapSet to fold over
  -> a
foldQapSet = foldr1

-- Alternative to @sequenceGenQap@
createMapGenQap :: Ord k => [GenQAP ((,) k) k] -> GenQAP (Map k) k
createMapGenQap genQaps = GenQAP inpLefts inpRights outputs targets
  where
    inpLefts = fmap Map.fromList . sequenceQapSet . map genQapInputsLeft $ genQaps
    inpRights = fmap Map.fromList . sequenceQapSet . map genQapInputsRight $ genQaps
    outputs =  fmap Map.fromList . sequenceQapSet . map genQapOutputs $ genQaps
    targets = Map.fromList . map genQapTarget $ genQaps

instance (Eq f, Num f, Pretty f, Show f) => Pretty (QAP f) where
  pretty (QAP inpsLeft inpsRight outps target)
    = vcat
      [ text "QAP:"
      , text "inputs left:"
      , indent 2 . text . show $ inpsLeft
      , text "inputs right:"
      , indent 2 . text . show $ inpsRight
      , text "outputs:"
      , indent 2 . text . show $ outps
      , text "target: " <> text (show target)
      ]

instance (Pretty f, Pretty (p f)) => Pretty (GenQAP p f) where
  pretty (GenQAP inpsLeft inpsRight outps target)
    = vcat
      [ text "QAP:"
      , text "inputs left:"
      , indent 2 $ pretty inpsLeft
      , text "inputs right:"
      , indent 2 $ pretty inpsRight
      , text "outputs:"
      , indent 2 $ pretty outps
      , text "target: " <> pretty target
      ]

instance Functor p => Functor (GenQAP p) where
  fmap f (GenQAP inpLeft inpRight outp target)
    = GenQAP (fmap (fmap f) inpLeft)
             (fmap (fmap f) inpRight)
             (fmap (fmap f) outp)
             (fmap f target)

-- | Verify whether an assignment of variables is consistent with the
-- given QAP
verifyAssignment
  :: (Eq f, Field f, Num f)
  => QAP f -- ^ circuit whose evaluation we want to verify
  -> QapSet f -- ^ vector containing the inputs, outputs and
              -- intermediate values (outputs of all the mul-gates)
  -> Bool
verifyAssignment qap assignment = isJust $ verificationWitness qap assignment

-- | Produce the polynomial witnessing the validity of given
-- assignment against the given QAP. Will return @Nothing@ if the
-- assignment is not valid.
--
-- In Pinocchio's terminology: this produces the h(x) such that p(x) =
-- h(x) * t(x) where t(x) is the target polynomial and p(x) is the
-- left input polynomials times the right input polynomials minus the
-- output polynomials.
verificationWitness
  :: forall f . (Eq f, Field f, Num f)
  => QAP f -- ^ circuit whose evaluation we want to verify
  -> QapSet f -- ^ vector containing the inputs, outputs and
              -- intermediate values (outputs of all the mul-gates)
  -> Maybe (VPoly f)
verificationWitness = verificationWitnessZk 0 0 0

verificationWitnessZk
  :: (Eq f, Field f, Num f)
  => f
  -> f
  -> f
  -> QAP f -- ^ circuit whose evaluation we want to verify
  -> QapSet f -- ^ vector containing the inputs, outputs and
              -- intermediate values (outputs of all the mul-gates)
  -> Maybe (VPoly f)
verificationWitnessZk delta1 delta2 delta3 QAP {..} assignment
  = if remainder == 0
    then Just quotient
    else Nothing
  where
    scaleWithAssignment x = combineWithDefaults (\a b -> monomial 0 b * a) 0 0 x assignment
    leftInputPoly
      = (monomial 0 delta1 * qapTarget)
        + sumQap (scaleWithAssignment qapInputsLeft)
    rightInputPoly
      = (monomial 0 delta2 * qapTarget)
        + sumQap (scaleWithAssignment qapInputsRight)
    outputPoly
      = (monomial 0 delta3 * qapTarget)
        + sumQap (scaleWithAssignment qapOutputs)
    sumQap = foldQapSet (+)
    inputOutputPoly
      = (leftInputPoly * rightInputPoly) - outputPoly
    (quotient, remainder) = quotRem inputOutputPoly qapTarget

-- | Lookup the value at the given wire label in the
-- @QapSet@.
lookupAtWire :: Wire -> QapSet a -> Maybe a
lookupAtWire (InputWire ix) QapSet { qapSetInput = inps }
  = Map.lookup ix inps
lookupAtWire (IntermediateWire ix) QapSet { qapSetIntermediate = mids }
  = Map.lookup ix mids
lookupAtWire (OutputWire ix) QapSet { qapSetOutput = outps }
  = Map.lookup ix outps

-- | Update the value at the given wire label in the
-- @QapSet@. (Partial function at the moment.)
updateAtWire :: Wire -> a -> QapSet a -> QapSet a
updateAtWire (InputWire ix) a qs@QapSet { qapSetInput = inps }
  = qs { qapSetInput = Map.insert ix a inps }
updateAtWire (IntermediateWire ix) a qs@QapSet { qapSetIntermediate = mids }
  = qs { qapSetIntermediate = Map.insert ix a mids }
updateAtWire (OutputWire ix) a qs@QapSet { qapSetOutput = outps }
  = qs { qapSetOutput = Map.insert ix a outps }

-- | Update at multiple wires
updateAtWires :: [(Wire, a)] -> QapSet a -> QapSet a
updateAtWires wireVals vars
  = foldl' (\rest (wire, val) -> updateAtWire wire val rest) vars wireVals

-- | Convert a single multiplication- or equality-gate into a QAP
gateToQAP
  :: GaloisField k
  => (Int -> k)
  -> [k]         -- ^ arbitrarily chosen roots
  -> Gate Wire k -- ^ circuit to encode as a QAP
  -> QAP k
gateToQAP primRoots roots
  = createPolynomialsFFT primRoots . addMissingZeroes roots . createMapGenQap . gateToGenQAP roots

-- | Convert a single multiplication gate (with affine circuits for
-- inputs) into a GenQAP
gateToGenQAP
  :: (GaloisField k)
  => [k]         -- ^ arbitrarily chosen roots
  -> Gate Wire k -- ^ circuit to encode as a QAP
  -> [GenQAP ((,) k) k]
gateToGenQAP [root] (Mul l r wire)
  = pure
  . addOutputVals
  . addInputVals
  $ GenQAP
    { genQapInputsLeft = constantQapSet (root, leftInputConst)
    , genQapInputsRight = constantQapSet (root, rightInputConst)
    , genQapOutputs = constantQapSet (root, 0)
    , genQapTarget = (root, 0)
    }
  where
    (leftInputConst, leftInputVector) = affineCircuitToAffineMap l
    (rightInputConst, rightInputVector) = affineCircuitToAffineMap r

    addInputVals (GenQAP left right out t)
      = GenQAP (Map.foldrWithKey updateAtWire left $ fmap (root,) leftInputVector)
               (Map.foldrWithKey updateAtWire right $ fmap (root,) rightInputVector)
               out
               t

    addOutputVals (GenQAP left right out t)
      = GenQAP left
               right
               (updateAtWire wire (root, 1) out)
               t
gateToGenQAP [root0,root1] (Equal i m outputWire)
  = [qap0, qap1]
  where
    qap0 = GenQAP
      { genQapInputsLeft
          = updateAtWires [ (i, (root0, 1))
                          , (m, (root0, 0))
                          , (outputWire, (root0, 0))
                          ]
            $ constantQapSet (root0, 0)
      , genQapInputsRight
          = updateAtWires [ (i, (root0, 0))
                          , (m, (root0, 1))
                          , (outputWire, (root0, 0))
                          ]
            $ constantQapSet (root0, 0)
      , genQapOutputs
          = updateAtWires [ (i, (root0, 0))
                          , (m, (root0, 0))
                          , (outputWire, (root0, 1))
                          ]
            $ constantQapSet (root0, 0)
      , genQapTarget
          = (root0, 0)
      }
    qap1 = GenQAP
      { genQapInputsLeft
          = updateAtWires [ (i, (root1, 0))
                          , (m, (root1, 0))
                          , (outputWire, (root1, -1))
                          ]
            $ constantQapSet (root1, 1)
      , genQapInputsRight
          = updateAtWires [ (i, (root1, 1))
                          , (m, (root1, 0))
                          , (outputWire, (root1, 0))
                          ]
            $ constantQapSet (root1, 0)
      , genQapOutputs
          = updateAtWires [ (i, (root1, 0))
                          , (m, (root1, 0))
                          , (outputWire, (root1, 0))
                          ]
            $ constantQapSet (root1, 0)
      , genQapTarget
          = (root1, 0)
      }
gateToGenQAP (root:roots) (Split inp outputs)
  = if length roots /= length outputs
    then panic "gateToGenQAP: wrong number of roots supplied"
    else qap0:zipWith qaps roots outputs
  where
    qap0 = GenQAP
      { genQapInputsLeft
          = updateAtWires ((inp, (root, 0)):zipWith (\output i -> (output, (root, 2 `pow` i))) outputs [0 :: Integer ..])
            $ constantQapSet (root, 0)
      , genQapInputsRight
          = updateAtWires [(inp, (root, 0))]
            $ constantQapSet (root, 1)
      , genQapOutputs
          = updateAtWires [(inp, (root, 1))]
            $ constantQapSet (root, 0)
      , genQapTarget
          = (root, 0)
      }
    qaps r outp = GenQAP
      { genQapInputsLeft
          = updateAtWire outp (r, 1)
            $ constantQapSet (r, 0)
      , genQapInputsRight
          = updateAtWire outp (r, -1)
            $ constantQapSet (r, 1)
      , genQapOutputs
          = updateAtWire outp (r, 0)
            $ constantQapSet (r, 0)
      , genQapTarget
          = (r, 0)
      }
gateToGenQAP _ _ = panic "gateToGenQAP: wrong number of roots supplied"

-- | For the left input/right input/output polynomials: turn list of
-- coordinates into a polynomial that interpolates the
-- coordinates. For the target polynomial: define it as the product of
-- all monics t_g(x) := x - r_g where r_g is the root corresponding to
-- the gate g.


-- | Naive construction of polynomials using Lagrange interpolation
-- This has terrible complexity at the moment.
-- Use the FFT-based approach if possible.
createPolynomials :: forall k. (GaloisField k) => GenQAP (Map k) k -> QAP k
createPolynomials (GenQAP inpLeft inpRight outp targetRoots)
  = QAP
    { qapInputsLeft = fmap (lagrangeInterpolate . Map.toList) inpLeft
    , qapInputsRight = fmap (lagrangeInterpolate . Map.toList) inpRight
    , qapOutputs = fmap (lagrangeInterpolate . Map.toList) outp
    , qapTarget = foldl' (*) (monomial 0 1) . map ((\root -> toPoly $ V.fromList [-root, 1]) . fst) . Map.toList $ targetRoots
    }
    where
    lagrangeInterpolate :: [(k, k)] -> VPoly k
    lagrangeInterpolate xys = sum
      [ scale 0 f (roots `quot` root x)
      | f <- zipWith (/) ys phis
      | x <- xs
      ]
      where
        xs, ys :: [k]
        (xs,ys) = foldr (\(a, b) ~(as,bs) -> (a:as,b:bs)) ([],[]) xys
        phis :: [k]
        phis = map (eval (deriv roots)) xs
        roots :: VPoly k
        roots = foldl' (\acc xi -> acc * root xi) 1 xs     -- (X - x_0) * ... * (X - x_{n-1})
        root xi = toPoly . V.fromList $ [-xi,  1]          -- (X - x_i)

-- | Create polynomials using FFT-based polynomial operations instead
-- of naive.
createPolynomialsFFT
  :: GaloisField k
  => (Int -> k)       -- ^ function that gives us the primitive 2^k-th root
                      -- of unity
  -> GenQAP (Map k) k -- ^ GenQAP containing the coordinates we want
                      -- to interpolate
  -> QAP k
createPolynomialsFFT primRoots (GenQAP inpLeft inpRight outp targetRoots)
  = QAP
    { qapInputsLeft = fmap (FFT.interpolate primRoots . Map.elems) inpLeft
    , qapInputsRight = fmap (FFT.interpolate primRoots . Map.elems) inpRight
    , qapOutputs = fmap (FFT.interpolate primRoots . Map.elems) outp
    , qapTarget = FFT.fftTargetPoly primRoots (Map.size targetRoots)
    }


-- | Convert an arithmetic circuit into a GenQAP: perform every step
-- of the QAP translation except the final interpolation step.
arithCircuitToGenQAP
  :: GaloisField k
  => [[k]]          -- ^ arbitrarily chosen roots, one for each gate
  -> ArithCircuit k -- ^ circuit to encode as a QAP
  -> GenQAP (Map k) k
arithCircuitToGenQAP rootsPerGate (ArithCircuit gates)
  = addMissingZeroes (concat rootsPerGate)
  . createMapGenQap
  . concat
  $ zipWith gateToGenQAP rootsPerGate gates

-- | Convert an arithmetic circuit into a QAP
arithCircuitToQAP
  :: GaloisField k
  => [[k]] -- ^ arbitrarily chosen roots, one for each gate
  -> ArithCircuit k -- ^ circuit to encode as a QAP
  -> QAP k
arithCircuitToQAP roots circuit =
  createPolynomials
  $ arithCircuitToGenQAP roots circuit

-- | Convert an arithmetic circuit into a QAP
arithCircuitToQAPFFT
  :: GaloisField k
  => (Int -> k)       -- ^ function that gives us the primitive 2^k-th root
                      -- of unity
  -> [[k]]            -- ^ arbitrarily chosen roots, one for each gate
  -> ArithCircuit k   -- ^ circuit to encode as a QAP
  -> QAP k
arithCircuitToQAPFFT primRoots roots circuit =
  createPolynomialsFFT primRoots
  $ arithCircuitToGenQAP roots circuit
-- | Add zeroes for those roots that are missing, to prevent the
-- values in the GenQAP to be too sparse. (We can be sparse in wire
-- values, but not in values at roots, otherwise the interpolation
-- step is incorrect.)
addMissingZeroes
  :: forall f . (Ord f, Num f)
  => [f] -> GenQAP (Map f) f -> GenQAP (Map f) f
addMissingZeroes allRoots (GenQAP inpLeft inpRight outp t)
  = GenQAP (fmap (`Map.union` allZeroes) inpLeft)
           (fmap (`Map.union` allZeroes) inpRight)
           (fmap (`Map.union` allZeroes) outp)
           (t `Map.union` allZeroes)
  where
    allZeroes :: Map f f
    allZeroes = Map.fromList . map (,0) $ allRoots

-- | Generate a valid assignment for a single gate.
generateAssignmentGate
  :: (Bits f, Fractional f)
  => Gate Wire f -- ^ program
  -> Map Int f -- ^ inputs
  -> QapSet f
generateAssignmentGate program inps
  = evalGate
      lookupAtWire
      updateAtWire
      (initialQapSet inps)
      program

initialQapSet
  :: Num f
  => Map Int f -- ^ inputs
  -> QapSet f
initialQapSet inputs = QapSet 1 inputs Map.empty Map.empty

generateAssignment
  :: forall f . (Bits f, Fractional f)
  => ArithCircuit f -- ^ program
  -> Map Int f -- ^ inputs
  -> QapSet f
generateAssignment circuit inputs
  = evalArithCircuit lookupAtWire updateAtWire circuit $ initialQapSet inputs

qapSetToMap :: QapSet g -> Map Int g
qapSetToMap QapSet{..}
  = fromList [(0, qapSetConstant)]
    <> mapKeys ((+) 1) qapSetInput
    <> mapKeys ((+) (1 + numOfInputs)) qapSetIntermediate
    <> mapKeys ((+) (1 + numOfInputs + numOfInterms)) qapSetOutput
  where
    numOfInputs = maxKey qapSetInput
    numOfInterms = maxKey qapSetIntermediate

    maxKey :: Map Int a -> Int
    maxKey = maximumSafe . Map.keys

    maximumSafe :: (Num a, Ord a) => [a] -> a
    maximumSafe [] = 0
    maximumSafe ls = maximum ls + 1