{-# LANGUAGE DeriveAnyClass, DeriveGeneric, LambdaCase, StrictData #-}

-- | Definition of arithmetic circuits that only contain addition,
-- scalar multiplications and constant gates, along with its direct
-- evaluation and translation into affine maps.
module Circuit.Affine
  ( AffineCircuit (..),
    collectInputsAffine,
    mapVarsAffine,
    evalAffineCircuit,
    affineCircuitToAffineMap,
    evalAffineMap,
    dotProduct,
  )
where

import Data.Aeson (FromJSON, ToJSON)
import Data.Map (Map)
import qualified Data.Map as Map
import Protolude
import Text.PrettyPrint.Leijen.Text (Doc, Pretty(..), parens, text, (<+>))

-- | Arithmetic circuits without multiplication gates, i.e. circuits
-- that describe affine transformations. The type parameter @i@ is the
-- type of variable labels and @f@ is the type of scalars and constants.
data AffineCircuit i f
  = -- | Sum of two sub-circuits.
    Add (AffineCircuit i f) (AffineCircuit i f)
  | -- | A sub-circuit multiplied by a fixed scalar.
    ScalarMul f (AffineCircuit i f)
  | -- | A constant value.
    ConstGate f
  | -- | A reference to an input variable.
    Var i
  deriving (Read, Eq, Show, Generic, NFData, FromJSON, ToJSON)

-- | List every variable referenced by the circuit, in left-to-right
-- order. A variable that occurs several times is listed several times.
collectInputsAffine :: Ord i => AffineCircuit i f -> [i]
collectInputsAffine circuit = go circuit []
  where
    -- Prepend the variables of @node@ to @rest@. Threading the tail
    -- through the traversal avoids the repeated copying that nested
    -- @(++)@ causes on left-leaning trees.
    go node rest = case node of
      Add l r -> go l (go r rest)
      ScalarMul _ x -> go x rest
      ConstGate _ -> rest
      Var i -> i : rest

-- | Render a circuit as an infix expression. Addition is printed at
-- precedence 6 and scalar multiplication at precedence 7, so a sum that
-- appears as the operand of a scalar multiplication is parenthesised,
-- e.g. @2 * (x + y)@.
instance (Pretty i, Show f) => Pretty (AffineCircuit i f) where
  pretty = prettyPrec 0
    where
      -- The 'Int' is the precedence of the enclosing context.
      prettyPrec :: (Pretty i, Show f) => Int -> AffineCircuit i f -> Doc
      prettyPrec p e =
        case e of
          Var v -> pretty v
          ConstGate f -> text $ show f
          ScalarMul f e1 ->
            -- The operand is rendered at the multiplication's own
            -- precedence; otherwise a nested 'Add' would lose its
            -- parentheses and the output would read as a different
            -- expression.
            parensPrec 7 p $
              text (show f) <+> text "*" <+> prettyPrec 7 e1
          Add e1 e2 ->
            parensPrec 6 p $
              prettyPrec 6 e1 <+> text "+" <+> prettyPrec 6 e2

      -- Wrap in parentheses only when the context binds tighter than
      -- the operator being printed.
      parensPrec :: Int -> Int -> Doc -> Doc
      parensPrec opPrec p = if p > opPrec then parens else identity

-- | Apply mapping to variable names, i.e. rename variables. (Ideally
-- the mapping is injective.)
mapVarsAffine :: (i -> j) -> AffineCircuit i f -> AffineCircuit j f
mapVarsAffine rename = go
  where
    -- Rebuild the tree node by node; only 'Var' leaves are changed.
    go circuit = case circuit of
      Var name -> Var (rename name)
      ConstGate c -> ConstGate c
      ScalarMul s sub -> ScalarMul s (go sub)
      Add lhs rhs -> Add (go lhs) (go rhs)

-- | Compute the value of a multiplication-free circuit for a given
-- variable assignment. The assignment is expected to cover every
-- variable mentioned in the circuit; a variable whose lookup yields
-- 'Nothing' contributes 0.
evalAffineCircuit ::
  Num f =>
  -- | how to look a variable up in the assignment
  (i -> vars -> Maybe f) ->
  -- | the variable assignment
  vars ->
  -- | the circuit to evaluate
  AffineCircuit i f ->
  f
evalAffineCircuit lookupVar vars = go
  where
    go (ConstGate c) = c
    go (Var name) =
      case lookupVar name vars of
        Just value -> value
        -- Missing variables are silently read as zero.
        Nothing -> 0
    go (Add lhs rhs) = go lhs + go rhs
    go (ScalarMul scalar sub) = go sub * scalar

-- | Flatten a multiplication-free circuit into a constant term plus a
-- coefficient per variable. The coefficients are kept in a @Map@ so
-- that variables which do not occur need no entry.
affineCircuitToAffineMap ::
  (Num f, Ord i) =>
  -- | the circuit to flatten
  AffineCircuit i f ->
  -- | constant term and per-variable coefficients
  (f, Map i f)
affineCircuitToAffineMap circuit = case circuit of
  ConstGate c -> (c, Map.empty)
  Var name -> (0, Map.singleton name 1)
  ScalarMul scalar sub ->
    -- Scaling distributes over both the constant and every coefficient.
    let (subConst, subCoeffs) = affineCircuitToAffineMap sub
     in (scalar * subConst, Map.map (scalar *) subCoeffs)
  Add lhs rhs ->
    -- Constants add up; coefficients of shared variables are summed.
    let (constL, coeffsL) = affineCircuitToAffineMap lhs
        (constR, coeffsR) = affineCircuitToAffineMap rhs
     in (constL + constR, Map.unionWith (+) coeffsL coeffsR)

-- | Evaluate the affine map that represents a multiplication-free
-- circuit against the given inputs. A variable that the affine map
-- refers to but the input map lacks is treated as 0.
evalAffineMap ::
  (Num f, Ord i) =>
  -- | affine map as a pair of constant term and per-variable coefficients
  (f, Map i f) ->
  -- | values of the input variables
  Map i f ->
  f
evalAffineMap (offset, coefficients) assignment =
  offset + dotProduct coefficients assignment

-- | Sum, over every key of the second map, of that key's value times
-- the value stored under the same key in the first map. A key missing
-- from the first map counts as 0 there; keys that occur only in the
-- first map do not contribute.
dotProduct :: (Num f, Ord i) => Map i f -> Map i f -> f
dotProduct lookupSide drivingSide =
  sum
    [ value * Map.findWithDefault 0 key lookupSide
      | (key, value) <- Map.toList drivingSide
    ]