{-# LANGUAGE DeriveAnyClass, DeriveGeneric, LambdaCase, ScopedTypeVariables,
StrictData #-}
module Circuit.Arithmetic
( Gate (..),
mapVarsGate,
collectInputsGate,
outputWires,
ArithCircuit (..),
fetchVars,
generateRoots,
validArithCircuit,
Wire (..),
evalGate,
evalArithCircuit,
unsplit,
)
where
import Circuit.Affine (AffineCircuit(..), collectInputsAffine,
evalAffineCircuit, mapVarsAffine)
import Data.Aeson (FromJSON, ToJSON)
import Protolude
import Text.PrettyPrint.Leijen.Text as PP (Pretty(..), hsep, list, parens, text,
vcat)
-- | Label of a wire in an arithmetic circuit: an 'Int' index tagged with the
-- role of the wire.
--
-- 'validArithCircuit' gives the roles the following meaning: input wires may
-- be read by any gate but never written, intermediate wires may be read only
-- after an earlier gate has produced them, and output wires may be written
-- but never read.
data Wire
  = -- | Rendered by 'pretty' as @input_N@.
    InputWire Int
  | -- | Rendered by 'pretty' as @imm_N@.
    IntermediateWire Int
  | -- | Rendered by 'pretty' as @output_N@.
    OutputWire Int
  deriving (Show, Eq, Ord, Generic, NFData, ToJSON, FromJSON)
-- | Renders a wire as a role prefix (@input_@, @imm_@ or @output_@) followed
-- by its index.
instance Pretty Wire where
  pretty = \case
    InputWire ix -> labelled "input_" ix
    IntermediateWire ix -> labelled "imm_" ix
    OutputWire ix -> labelled "output_" ix
    where
      -- Prefix immediately followed by the pretty-printed index.
      labelled prefix ix = text prefix <> pretty ix
-- | A single gate of an arithmetic circuit. @i@ is the type of wire labels
-- and @f@ the type of the values carried on the wires (see 'evalGate').
data Gate i f
  = -- | Multiplication gate: 'evalGate' stores the product of the two
    -- evaluated affine circuits in 'mulOutput'.
    Mul
      { -- | Left factor.
        mulLeft :: AffineCircuit i f,
        -- | Right factor.
        mulRight :: AffineCircuit i f,
        -- | Wire receiving the product.
        mulOutput :: i
      }
  | -- | Zero test: 'evalGate' stores @0@ in 'eqOutput' when the value on
    -- 'eqInput' is zero and @1@ otherwise.
    Equal
      { -- | Wire whose value is tested.
        eqInput :: i,
        -- | Auxiliary wire that 'evalGate' sets to @0@ when the input is
        -- zero and to the reciprocal of the input otherwise.
        eqMagic :: i,
        -- | Wire receiving the @0@/@1@ result.
        eqOutput :: i
      }
  | -- | Bit decomposition: 'evalGate' stores bit @k@ of the value on
    -- 'splitInput' (as @0@ or @1@) in the @k@-th element of 'splitOutputs',
    -- counting from zero.
    Split
      { -- | Wire whose value is decomposed.
        splitInput :: i,
        -- | One wire per extracted bit, lowest bit index first.
        splitOutputs :: [i]
      }
  deriving (Show, Eq, Generic, NFData, FromJSON, ToJSON)
-- | Collects the input labels of both factors of a 'Mul' gate, left factor
-- first, using 'collectInputsAffine'.
--
-- This function is partial: it panics for every gate that is not a 'Mul'.
collectInputsGate :: Ord i => Gate i f -> [i]
collectInputsGate gate =
  case gate of
    Mul {mulLeft = l, mulRight = r} ->
      collectInputsAffine l <> collectInputsAffine r
    _ ->
      panic "collectInputsGate: only supports mul gates"
-- | The wires a gate reports as its outputs: the single output wire of a
-- 'Mul' or 'Equal' gate, or all bit wires of a 'Split' gate. The 'eqMagic'
-- wire of an 'Equal' gate is not included.
outputWires :: Gate i f -> [i]
outputWires gate =
  case gate of
    Mul {mulOutput = out} -> [out]
    Equal {eqOutput = out} -> [out]
    Split {splitOutputs = outs} -> outs
-- | Renders a gate as an assignment of the form @target := expression@.
instance (Pretty i, Show f) => Pretty (Gate i f) where
  pretty = \case
    Mul l r o ->
      assign (pretty o) [parens (pretty l), text "*", parens (pretty r)]
    Equal i _ o ->
      -- The auxiliary wire is not shown.
      assign (pretty o) [pretty i, text "== 0 ? 0 : 1"]
    Split inp outputs ->
      assign (PP.list (map pretty outputs)) [text "split", pretty inp]
    where
      -- Target, then ":=", then the right-hand side, separated by spaces.
      assign target rhs = hsep (target : text ":=" : rhs)
-- | Applies a relabelling function to every wire label of a gate, including
-- the labels inside the affine circuits of a 'Mul' gate (via
-- 'mapVarsAffine').
mapVarsGate :: (i -> j) -> Gate i f -> Gate j f
mapVarsGate rename (Mul l r o) =
  Mul
    { mulLeft = mapVarsAffine rename l,
      mulRight = mapVarsAffine rename r,
      mulOutput = rename o
    }
mapVarsGate rename (Equal inp magic out) =
  Equal
    { eqInput = rename inp,
      eqMagic = rename magic,
      eqOutput = rename out
    }
mapVarsGate rename (Split inp outs) =
  Split
    { splitInput = rename inp,
      splitOutputs = map rename outs
    }
-- | Evaluates a single gate against an environment of wire values and
-- returns the environment extended with the values the gate produces.
--
-- * 'Mul': evaluates both affine circuits with 'evalAffineCircuit' and
--   stores their product in the output wire.
-- * 'Equal': stores @0@ in the output wire when the input value is zero and
--   @1@ otherwise; the auxiliary wire receives @0@ or the reciprocal of the
--   input, respectively.
-- * 'Split': stores bit @k@ of the input value (as @0@ or @1@) in the @k@-th
--   output wire, counting from zero.
--
-- Panics when the input wire of an 'Equal' or 'Split' gate has no value in
-- the environment.
evalGate ::
  (Bits f, Fractional f) =>
  -- | Looks up the value of a wire in the environment.
  (i -> vars -> Maybe f) ->
  -- | Stores a value for a wire in the environment.
  (i -> f -> vars -> vars) ->
  -- | Environment before the gate is evaluated.
  vars ->
  -- | Gate to evaluate.
  Gate i f ->
  vars
evalGate lookupVar updateVar vars gate =
  case gate of
    Mul l r outputWire ->
      let lval = evalAffineCircuit lookupVar vars l
          rval = evalAffineCircuit lookupVar vars r
       in updateVar outputWire (lval * rval) vars
    Equal i m outputWire ->
      case lookupVar i vars of
        Nothing ->
          panic "evalGate: the impossible happened"
        Just inp ->
          let isZero = inp == 0
              res = if isZero then 0 else 1
              -- The reciprocal is only taken for a non-zero input.
              mid = if isZero then 0 else recip inp
           in updateVar outputWire res (updateVar m mid vars)
    Split i os ->
      case lookupVar i vars of
        Nothing ->
          panic "evalGate: the impossible happened"
        Just inp ->
          let bitValue ix = if testBit inp ix then 1 else 0
              setWire env (ix, out) = updateVar out (bitValue ix) env
           in -- Strict left fold over (bit index, wire) pairs: no tuple
              -- accumulator and no chain of suspended environment updates.
              foldl' setWire vars (zip [0 ..] os)
-- | An arithmetic circuit: a list of gates over 'Wire' labels.
-- 'evalArithCircuit' evaluates the gates in list order.
newtype ArithCircuit f
  = ArithCircuit [Gate Wire f]
  deriving (Eq, Show, Generic, NFData, FromJSON, ToJSON)
-- | Renders a circuit with one gate per line, in list order.
instance Show f => Pretty (ArithCircuit f) where
  pretty (ArithCircuit gates) = vcat (pretty <$> gates)
-- | Checks that the gates of a circuit only refer to wires that are
-- available at the point where they are used.
--
-- Walking the gates in order, every gate must satisfy both of:
--
-- * none of its 'outputWires' is an 'InputWire';
-- * every wire it reads is either an 'InputWire' or an 'IntermediateWire'
--   that an earlier gate listed among its 'outputWires'. Reading an
--   'OutputWire' is never valid.
--
-- The 'eqMagic' wire of an 'Equal' gate is neither checked nor recorded as
-- defined. The walk stops at the first gate that fails a check.
validArithCircuit ::
  ArithCircuit f -> Bool
validArithCircuit (ArithCircuit gates) =
  go [] gates
  where
    -- @definedWires@ holds the outputs of all gates seen so far. Explicit
    -- recursion with '&&' short-circuits on the first invalid gate and does
    -- not build the nested thunks a lazy left fold would.
    go _ [] = True
    go definedWires (gate : rest) =
      all isNotInput outs
        && all (validWire definedWires) (fetchVarsGate gate)
        && go (outs ++ definedWires) rest
      where
        outs = outputWires gate

    isNotInput (InputWire _) = False
    isNotInput (OutputWire _) = True
    isNotInput (IntermediateWire _) = True

    validWire _ (InputWire _) = True
    validWire _ (OutputWire _) = False
    validWire definedWires i@(IntermediateWire _) = i `elem` definedWires

    -- Wires a gate reads.
    fetchVarsGate (Mul l r _) = fetchVars l ++ fetchVars r
    fetchVarsGate (Equal i _ _) = [i]
    fetchVarsGate (Split i _) = [i]
-- | Lists the wires referenced by an affine circuit, left to right.
-- A wire that occurs several times is listed once per occurrence.
fetchVars :: AffineCircuit Wire f -> [Wire]
fetchVars = \case
  Var wire -> [wire]
  ConstGate _ -> []
  ScalarMul _ inner -> fetchVars inner
  Add lhs rhs -> fetchVars lhs <> fetchVars rhs
-- | Draws values from the given action for every gate of a circuit and
-- returns them grouped per gate, in gate order.
--
-- The number of values drawn depends on the gate: one for 'Mul', two for
-- 'Equal', and one plus one per output wire for 'Split'. Effects are
-- sequenced gate by gate, left to right.
generateRoots ::
  Applicative m =>
  -- | Action producing one fresh value.
  m f ->
  ArithCircuit f ->
  m [[f]]
generateRoots takeRoot (ArithCircuit gates) =
  traverse rootsFor gates
  where
    -- Values drawn for a single gate.
    rootsFor = \case
      Mul {} ->
        (\r -> [r]) <$> takeRoot
      Equal {} ->
        (\r0 r1 -> [r0, r1]) <$> takeRoot <*> takeRoot
      Split _ outputs ->
        (:) <$> takeRoot <*> traverse (const takeRoot) outputs
-- | Evaluates every gate of a circuit in list order with 'evalGate',
-- threading the environment of wire values through a strict left fold, and
-- returns the final environment.
evalArithCircuit ::
  forall f vars.
  (Bits f, Fractional f) =>
  -- | Looks up the value of a wire in the environment.
  (Wire -> vars -> Maybe f) ->
  -- | Stores a value for a wire in the environment.
  (Wire -> f -> vars -> vars) ->
  ArithCircuit f ->
  -- | Initial environment.
  vars ->
  vars
evalArithCircuit lookupVar updateVar (ArithCircuit gates) initialVars =
  foldl' step initialVars gates
  where
    step env gate = evalGate lookupVar updateVar env gate
-- | Builds the affine circuit @0 + 1 * w0 + 2 * w1 + 4 * w2 + ...@ from a
-- list of wires, weighting the @k@-th wire (counting from zero) by @2 ^ k@.
--
-- The result is a left-nested chain of 'Add' nodes starting from
-- @ConstGate 0@. Presumably intended to recombine the bit wires produced by
-- a 'Split' gate.
unsplit ::
  Num f =>
  [Wire] ->
  AffineCircuit Wire f
unsplit wires =
  -- Strict left fold over (bit index, wire) pairs; avoids the lazy tuple
  -- accumulator and hand-threaded counter.
  foldl' addBit (ConstGate 0) (zip [(0 :: Integer) ..] wires)
  where
    addBit acc (ix, wire) = Add acc (ScalarMul (2 ^ ix) (Var wire))