{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Crypto.Lol.Applications.SymmBGV
(
SK, PT, CT
, genSK, genAnotherSK
, encrypt
, errorTerm, decrypt
, addPublic, mulPublic
, modSwitch, modSwitchPT
, KSHint
, ksLinearHint, ksQuadCircHint
, keySwitchLinear, keySwitchQuadCirc
, embedSK, embedCT, twaceCT
, TunnelHint, tunnelHint
, tunnel
, addCT, mulCT, negateCT
, AddCTCtx, MulCTCtx, NegateCTCtx
, Max, type (+)
, GenSKCtx, EncryptCtx, ToSDCtx
, ErrorTermCtx, DecryptCtx
, AddPublicCtx, MulPublicCtx
, ModSwitchCtx, ModSwitchPTCtx
, KSHintCtx, KeySwitchCtx
, TunnelHintCtx, TunnelCtx
, SwitchCtx, LWECtx
) where
import Crypto.Lol hiding (sin)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.Proto
import qualified Crypto.Proto.BGV.KSHint as P
import qualified Crypto.Proto.BGV.RqPolynomial as P
import qualified Crypto.Proto.BGV.SecretKey as P
import qualified Crypto.Proto.BGV.TunnelHint as P
import Crypto.Proto.Lol.R (R)
import Crypto.Proto.Lol.RqProduct (RqProduct)
import Control.Applicative hiding ((*>))
import Control.DeepSeq
import Control.Monad as CM
import Control.Monad.Random hiding (lift)
import Data.Constraint
import Data.Maybe
import Data.Singletons.Prelude (Max)
import Data.Typeable
import GHC.Generics (Generic)
import GHC.TypeLits (type (+), Nat)
import Unsafe.Coerce
import MathObj.Polynomial as P
-- | Secret key. Besides the key element of type @r@, the constructor stores
-- (existentially) the parameter @v@ the key was sampled with (see 'genSK'),
-- so that 'genAnotherSK' can sample a new key with the same parameter.
data SK r where
  SK :: (ToRational v, NFData v) => v -> r -> SK r
-- | Plaintext: simply an element of the plaintext ring @rp@.
type PT rp = rp
-- | Ciphertext encoding: most-significant-digit ('MSD') or
-- least-significant-digit ('LSD'). 'toMSD' and 'toLSD' convert between them.
data Encoding
  = MSD
  | LSD
  deriving stock (Show, Eq, Generic)
  deriving anyclass NFData
-- | Ciphertext.
--
-- Type parameters: @d@ is a type-level degree of the ciphertext polynomial
-- ('encrypt' produces @CT 1@, 'mulCT' adds degrees); @m@ is the index of the
-- plaintext ring (what 'decrypt' returns); @zp@ is the plaintext scalar ring;
-- @r'q@ is the ring the ciphertext polynomial's coefficients live in.
data CT (d :: Nat) m zp r'q =
  CT
    !Encoding          -- ^ 'LSD' or 'MSD' encoding
    !Int               -- ^ @k@: number of factors of @g@ that 'decrypt'
                       --   divides out of the error term
    !zp                -- ^ @l@: scalar that 'decrypt' multiplies the result by
    !(Polynomial r'q)  -- ^ ciphertext polynomial; evaluating it at the secret
                       --   key gives the error term (see 'errorTerm')
  deriving (Show, Generic)

deriving instance (NFData zp, NFData r'q) => NFData (CT d m zp r'q)
-- | Constraints needed by 'genSK'.
type GenSKCtx c m z v =
  (ToInteger z, RoundedGaussianCyc (c m z), ToRational v, NFData v)

-- | Samples a fresh secret key with 'roundedGaussian' using parameter @v@;
-- @v@ is stored in the key alongside the sampled element.
genSK :: (GenSKCtx c m z v, MonadRandom rnd) => v -> rnd (SK (c m z))
genSK v = do
  keyElt <- roundedGaussian v
  return (SK v keyElt)
-- | Generates a new secret key with the same parameter that was used to
-- generate the given key. The parameter is read out of the 'SK'
-- constructor; the input key's element is ignored.
--
-- The input key's element type @r@ is deliberately unconstrained. The
-- previous signature required @GenSKCtx c m z r@. That demanded
-- @ToRational@ and @NFData@ of the key element's type rather than of the
-- parameter, which cyclotomic ring types are unlikely to satisfy. The
-- parameter's own @ToRational@/@NFData@ dictionaries come from the 'SK'
-- pattern match.
-- The explicit @forall@ keeps the type-variable order (c, m, z, r, rnd) of
-- the old signature for callers using type applications.
genAnotherSK :: forall c m z r rnd .
  (ToInteger z, RoundedGaussianCyc (c m z), MonadRandom rnd)
  => SK r -> rnd (SK (c m z))
genAnotherSK (SK v _) = genSK v
-- | Constraints needed by 'encrypt'.
type EncryptCtx c m m' z zp zq =
  (Ring zp, Cyclotomic (c m' zq), Ring (c m' zq), Random (c m' zq),
   Reduce (c m' z) (c m' zq), Reduce (LiftOf (c m' zp)) (c m' zq),
   CosetGaussianCyc (c m' zp), ExtensionCyc c zp, m `Divides` m')

-- | Symmetric-key encryption of a plaintext.
--
-- Produces the degree-1, 'LSD'-encoded ciphertext @c0 + c1*X@ with
-- @k = zero@ and @l = one@, where:
--
--   * @c1@ is drawn with 'getRandom';
--   * @e@ is sampled by 'cosetGaussian' from the key's parameter and the
--     embedded plaintext;
--   * @c0 = e - c1*s@ for the reduced key @s@.
--
-- Evaluating the polynomial at @s@ therefore gives @e@.
--
-- The key is reduced (and CRT-advised) once per partial application of
-- 'encrypt' to a key, so it is shared across plaintexts.
encrypt :: forall c m m' z zp zq rnd .
  (EncryptCtx c m m' z zp zq, MonadRandom rnd)
  => SK (c m' z) -> PT (c m zp) -> rnd (CT 1 m zp (c m' zq))
encrypt (SK svar s) = encryptWith
  where
    keyQ = adviseCRT (reduce s)
    encryptWith pt = do
      err <- cosetGaussian svar (embed pt :: c m' zp)
      mask <- getRandom
      let c0 = reduce err - mask * keyQ
      return (CT LSD zero one (fromCoeffs [c0, mask]))
-- | Constraints needed by 'errorTerm'.
type ErrorTermCtx c m' z zp zq =
  (ToSDCtx c m' zp zq, Ring (c m' zq), Reduce (c m' z) (c m' zq), LiftCyc (c m' zq))

-- | Converts the ciphertext to 'LSD' form, evaluates its polynomial at the
-- (reduced) secret key, and lifts the result with 'liftDec'. The key is
-- reduced once per partial application to a key.
errorTerm :: ErrorTermCtx c m' z zp zq
  => SK (c m' z) -> CT d m zp (c m' zq) -> LiftOf (c m' zq)
errorTerm (SK _ s) = liftDec . evalAtKey . lsdPoly . toLSD
  where
    keyQ = reduce s
    evalAtKey poly = evaluate poly keyQ
    lsdPoly (CT LSD _ _ poly) = poly
-- | Divides by @g@ using 'divG', failing with a descriptive error when 'divG'
-- returns 'Nothing'.
--
-- Previously implemented as @fromJust . divG@. On failure that gave only a
-- generic 'fromJust' error, with no hint of where the failure came from.
divG' :: (Cyclotomic c) => c -> c
divG' = fromMaybe (error "divG': divG returned Nothing (element not divisible by g)") . divG
-- | Constraints needed by 'decrypt'.
type DecryptCtx c m m' z zp zq =
  (ErrorTermCtx c m' z zp zq, Cyclotomic (c m' zp), Module zp (c m zp),
   Reduce (LiftOf (c m' zq)) (c m' zp), ExtensionCyc c zp, m `Divides` m')

-- | Decrypts a ciphertext of any degree. The steps are:
--
--   1. put the ciphertext in 'LSD' form;
--   2. reduce its error term into @zp@;
--   3. divide out @k@ factors of @g@;
--   4. apply 'twace' down to the plaintext ring;
--   5. scale by @l@.
decrypt :: forall c m m' z zp zq d. DecryptCtx c m m' z zp zq
  => SK (c m' z) -> CT d m zp (c m' zq) -> PT (c m zp)
decrypt sk ct =
  case toLSD ct of
    lsdCT@(CT LSD k l _) ->
      let errP :: c m' zp
          errP = reduce (errorTerm sk lsdCT)
          unscaled = iterate divG' errP !! k
      in l *> twace unscaled
-- | Constraints needed to convert between 'LSD' and 'MSD' encodings.
type ToSDCtx c m' zp zq =
  (Encode zp zq, Cyclotomic (c m' zq), Ring (c m' zq), Module zq (c m' zq))

-- | 'toLSD' / 'toMSD' put a ciphertext into 'LSD' / 'MSD' encoding.
--
-- A ciphertext already in the target encoding is returned unchanged.
-- Otherwise @l@ is multiplied by the @zp@ factor and the polynomial by the
-- @zq@ factor from 'msdToLSD' (resp. 'lsdToMSD'). The factors are computed
-- once and shared.
toLSD, toMSD :: forall c m m' zp zq d.
  ToSDCtx c m' zp zq => CT d m zp (c m' zq) -> CT d m zp (c m' zq)
toMSD =
  let (pFactor, qFactor :: zq) = lsdToMSD
      convert ct@(CT enc k l c)
        | enc == MSD = ct
        | otherwise  = CT MSD k (pFactor * l) (qFactor *> c)
  in convert
toLSD =
  let (pFactor, qFactor :: zq) = msdToLSD
      convert ct@(CT enc k l c)
        | enc == LSD = ct
        | otherwise  = CT LSD k (pFactor * l) (qFactor *> c)
  in convert
-- | Rescales an 'MSD' ciphertext polynomial from @zq@ to @zq'@. The constant
-- coefficient is rescaled with 'rescaleDec' and all higher coefficients with
-- 'rescalePow'.
modSwitchMSD :: (RescaleCyc (c m') zq zq') => Polynomial (c m' zq) -> Polynomial (c m' zq')
modSwitchMSD poly =
  fromCoeffs $ case coeffs poly of
    []           -> []
    (lead : rest) -> rescaleDec lead : fmap rescalePow rest
-- | Constraints needed by 'modSwitch'.
type ModSwitchCtx c m' zp zq zq' =
  (RescaleCyc (c m') zq zq', ToSDCtx c m' zp zq)

-- | Switches a ciphertext's modulus from @zq@ to @zq'@. The ciphertext is put
-- in 'MSD' form and its polynomial rescaled with 'modSwitchMSD'; @k@ and @l@
-- are kept.
modSwitch :: (ModSwitchCtx c m' zp zq zq')
  => CT d m zp (c m' zq) -> CT d m zp (c m' zq')
modSwitch ct =
  case toMSD ct of
    CT MSD k l poly -> CT MSD k l (modSwitchMSD poly)
-- | Constraints needed by 'modSwitchPT'.
type ModSwitchPTCtx c m' zp zp' zq =
  (Lift' zp, Reduce (LiftOf zp) zp', ToSDCtx c m' zp zq)

-- | Changes the plaintext ring from @zp@ to @zp'@. The ciphertext is put in
-- 'MSD' form; its polynomial and @k@ are kept, and @l@ is lifted out of @zp@
-- and reduced into @zp'@.
modSwitchPT :: ModSwitchPTCtx c m' zp zp' zq => CT d m zp (c m' zq) -> CT d m zp' (c m' zq)
modSwitchPT ct =
  case toMSD ct of
    CT MSD k l poly -> CT MSD k (reduce (lift l)) poly
-- | Constraints needed by 'lweSample'.
type LWECtx c m' z zq =
  (Cyclotomic (c m' zq), RoundedGaussianCyc (c m' z), Reduce (c m' z) (c m' zq),
   Random (c m' zq), Ring (c m' zq))

-- | Samples the linear polynomial @(-c1*s + e) + c1*X@, where:
--
--   * @c1@ is drawn with 'getRandom' (then CRT-advised);
--   * @e@ is drawn with 'roundedGaussian' using the key's parameter.
--
-- Evaluating the result at the key @s@ gives @e@. The negated, reduced key
-- is computed once and shared by every run of the returned action.
lweSample :: (LWECtx c m' z zq, MonadRandom rnd)
  => SK (c m' z) -> rnd (Polynomial (c m' zq))
lweSample (SK svar s) = sampleOnce
  where
    negKeyQ = adviseCRT (negate (reduce s))
    sampleOnce = do
      err <- roundedGaussian svar
      c1 <- adviseCRT <$> getRandom
      return (fromCoeffs [c1 * negKeyQ + reduce (err `asTypeOf` s), c1])
-- | Key-switching hint: one polynomial per gadget component (see 'ksHint').
-- The phantom parameter @gad@ records which gadget was used.
newtype KSHint gad r'q' = KSHint [Polynomial r'q']
  deriving stock Generic
  deriving anyclass NFData
-- | Constraints needed to generate key-switching hints.
type KSHintCtx gad c m' z zq = (LWECtx c m' z zq, Gadget gad (c m' zq))

-- | Builds a key-switching hint for @val@ under @skout@. @val@ is reduced to
-- @zq@ and gadget-encoded; each component, as a constant polynomial, is added
-- to a fresh 'lweSample' under @skout@.
ksHint :: forall gad c m' z zq rnd . (KSHintCtx gad c m' z zq, MonadRandom rnd)
  => SK (c m' z) -> c m' z -> rnd (KSHint gad (c m' zq))
ksHint skout val = do
  let gadgetParts = encode @gad (reduce val)
  noise <- replicateM (length gadgetParts) (lweSample skout)
  let shifted = zipWith (\part smp -> P.const part + smp) gadgetParts noise
  return (KSHint shifted)
-- | Hint for 'keySwitchLinear'. The first key is the output key (the hint's
-- samples are under it). The second key is the input key, whose element the
-- hint encodes.
ksLinearHint :: forall gad c m' z zq' rnd . (KSHintCtx gad c m' z zq', MonadRandom rnd)
  => SK (c m' z)
  -> SK (c m' z)
  -> rnd (KSHint gad (c m' zq'))
ksLinearHint skout skin =
  case skin of
    SK _ inputElt -> ksHint skout inputElt
-- | Hint for 'keySwitchQuadCirc': encodes the square of the key's element,
-- with samples under the same key.
ksQuadCircHint :: forall gad c m' z zq' rnd .
  (KSHintCtx gad c m' z zq', Ring (c m' z), MonadRandom rnd)
  => SK (c m' z) -> rnd (KSHint gad (c m' zq'))
ksQuadCircHint sk =
  case sk of
    SK _ keyElt -> ksHint sk (keyElt * keyElt)
-- | Multiplies every element of a functor by @r@ on the left.
(*>>) :: (Ring r, Functor f) => r -> f r -> f r
r *>> xs = fmap (r *) xs

-- | Sum over pairs of @x_i *>> hint_i@, with each @x_i@ CRT-advised first.
-- Pairs are formed with 'zipWith', so surplus entries of either list are
-- ignored.
knapsack :: (r'q ~ c m' zq, Cyclotomic (c m' zq), Ring (c m' zq))
  => [Polynomial r'q] -> [r'q] -> Polynomial r'q
knapsack hint xs = sum (zipWith (\x h -> adviseCRT x *>> h) xs hint)
-- | Constraints needed for key switching.
type SwitchCtx gad c m' zq =
  (Cyclotomic (c m' zq), Ring (c m' zq), Decompose gad (c m' zq),
   Reduce (DecompOf (c m' zq)) (c m' zq))

-- | Applies a key-switching hint to one ciphertext coefficient. The
-- coefficient is gadget-decomposed, the parts are reduced back into the
-- ring, and the result is combined with the hint via 'knapsack'.
keySwitch :: forall gad c m' zq r'q . (SwitchCtx gad c m' zq, r'q ~ c m' zq)
  => KSHint gad r'q -> r'q -> Polynomial r'q
keySwitch (KSHint hint) c = knapsack hint decomposed
  where
    decomposed = fmap reduce (decompose @gad c)
-- | Constraints needed by 'keySwitchLinear' and 'keySwitchQuadCirc'.
type KeySwitchCtx gad c m' zp zq' =
  (ToSDCtx c m' zp zq', SwitchCtx gad c m' zq')

-- | Switches a linear ciphertext using a hint from 'ksLinearHint'. In 'MSD'
-- form, @c0 + c1*X@ becomes @c0 + keySwitch hint c1@.
--
-- A ciphertext with fewer than two coefficients is returned exactly as
-- given. More than two coefficients is an internal error.
keySwitchLinear :: forall gad c m m' zp zq' . KeySwitchCtx gad c m' zp zq'
  => KSHint gad (c m' zq')
  -> CT 1 m zp (c m' zq')
  -> CT 1 m zp (c m' zq')
keySwitchLinear hint ct =
  case toMSD ct of
    CT MSD k l poly ->
      case coeffs poly of
        [c0, c1] -> CT MSD k l (P.const c0 + keySwitch hint c1)
        []       -> ct
        [_]      -> ct
        _        -> error "keySwitchLinear: internal error"
-- | Relinearizes a degree-2 ciphertext using a hint from 'ksQuadCircHint'.
-- In 'MSD' form, @c0 + c1*X + c2*X^2@ becomes
-- @(c0 + c1*X) + keySwitch hint c2@.
--
-- A polynomial with at most two coefficients is kept as is (in 'MSD' form).
-- More than three coefficients is an internal error.
keySwitchQuadCirc :: forall gad c m m' zp zq' . KeySwitchCtx gad c m' zp zq'
  => KSHint gad (c m' zq')
  -> CT 2 m zp (c m' zq')
  -> CT 1 m zp (c m' zq')
keySwitchQuadCirc hint ct =
  let CT MSD k l poly = toMSD ct
      relinearized = case coeffs poly of
        [c0, c1, c2] -> P.fromCoeffs [c0, c1] + keySwitch hint c2
        []           -> poly
        [_]          -> poly
        [_, _]       -> poly
        _            -> error "keySwitchQuadCirc: internal error"
  in CT MSD k l relinearized
-- | Constraints needed by 'addPublic'.
type AddPublicCtx c m m' zp zq =
  (ToSDCtx c m' zp zq, Cyclotomic (c m zp), Module zp (c m zp),
   LiftCyc (c m zp), Reduce (LiftOf (c m zp)) (c m zq),
   ExtensionCyc c zq, m `Divides` m')

-- | Adds a public plaintext-ring element @b@ to an encrypted plaintext.
--
-- The ciphertext is put in 'LSD' form. Then @b@ is adjusted to match what
-- 'decrypt' will undo: multiplied by @g@ @k@ times and by @recip l@.
-- Finally it is lifted in the powerful basis, reduced to @zq@, embedded,
-- and added to the constant coefficient.
addPublic :: forall c m m' zp zq d. AddPublicCtx c m m' zp zq
  => c m zp -> CT d m zp (c m' zq) -> CT d m zp (c m' zq)
addPublic b ct =
  case toLSD ct of
    CT LSD k l poly ->
      let adjusted :: c m zq
          adjusted = reduce $ liftPow $ recip l *> (iterate mulG b !! k)
      in CT LSD k l (poly + P.const (embed adjusted))
-- | Multiplies every coefficient of the ciphertext polynomial by the lift of
-- @a@ reduced into @zq@. Encoding, @k@ and @l@ are unchanged.
mulScalar :: forall zp zq c m m' d .
  (Lift' zp, Reduce (LiftOf zp) zq, Module zq (c m' zq))
  => zp -> CT d m zp (c m' zq) -> CT d m zp (c m' zq)
mulScalar a (CT enc k l c) = CT enc k l (fmap (scalarQ *>) c)
  where
    scalarQ :: zq
    scalarQ = reduce (lift a)
-- | Constraints needed by 'mulPublic'.
type MulPublicCtx c m m' zp zq =
  (LiftCyc (c m zp), Reduce (LiftOf (c m zp)) (c m zq),
   ExtensionCyc c zq, m `Divides` m', Ring (c m' zq))

-- | Multiplies an encrypted plaintext by a public plaintext-ring element @a@.
-- @a@ is lifted in the powerful basis, reduced to @zq@, embedded, and
-- multiplied into every coefficient. Encoding, @k@ and @l@ are unchanged.
mulPublic :: forall c m m' zp zq d. MulPublicCtx c m m' zp zq
  => c m zp -> CT d m zp (c m' zq) -> CT d m zp (c m' zq)
mulPublic a (CT enc k l r) = CT enc k l (fmap (embeddedA *) r)
  where
    embeddedA = embed (reduce (liftPow a) :: c m zq)
-- | Multiplies every coefficient by @g@ and increments @k@. 'decrypt' divides
-- out @k@ factors of @g@, so the decrypted value is unaffected.
mulGCT :: (Cyclotomic r'q) => CT d m zp r'q -> CT d m zp r'q
mulGCT (CT enc k l c) = CT enc (k + 1) l (fmap mulG c)
-- | Constraints needed by 'addCT'.
type AddCTCtx c m m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, ToSDCtx c m' zp zq,
   Eq zp, m `Divides` m')

-- | Homomorphic addition. The operands are brought into agreement in this
-- order, recursing after each adjustment, and their polynomials are then
-- added:
--
--   1. differing @l@: the first ciphertext is scaled by @l1/l2@
--      ('mulScalar') and given @l2@;
--   2. differing @k@: the one with smaller @k@ gets extra factors of @g@
--      ('mulGCT');
--   3. mixed encodings: the 'LSD' operand is converted with 'toMSD'.
addCT :: forall c m m' zp zq d1 d2 . AddCTCtx c m m' zp zq
  => CT d1 m zp (c m' zq) -> CT d2 m zp (c m' zq) -> CT (Max d1 d2) m zp (c m' zq)
addCT ct1@(CT enc1 k1 l1 c1) ct2@(CT enc2 k2 l2 c2)
  | l1 /= l2 =
      case mulScalar (l1 * recip l2) ct1 of
        CT encS kS _ cS -> addCT (CT @d1 encS kS l2 cS) ct2
  | k1 < k2 = addCT (iterate mulGCT ct1 !! (k2 - k1)) ct2
  | k1 > k2 = addCT ct1 (iterate mulGCT ct2 !! (k1 - k2))
  | otherwise =
      case (enc1, enc2) of
        (LSD, MSD) -> addCT (toMSD ct1) ct2
        (MSD, LSD) -> addCT ct1 (toMSD ct2)
        _          -> CT enc1 k1 l1 (c1 + c2)
-- | Constraint needed by 'negateCT'.
type NegateCTCtx c m' zq = (Additive (c m' zq))

-- | Homomorphic negation: negates every coefficient; encoding, @k@ and @l@
-- are unchanged.
negateCT :: NegateCTCtx c m' zq => CT d m zp (c m' zq) -> CT d m zp (c m' zq)
negateCT (CT enc k l c) = CT enc k l (fmap negate c)
-- | Axiom: type-level addition is commutative, @(d1 + d2) ~ (d2 + d1)@.
-- Used by 'mulCT' when it swaps its operands.
--
-- SAFETY: implemented by 'unsafeCoerce'-ing the trivial @Dict ()@ into the
-- equality dictionary. This is the usual "unsafe axiom" trick, presumably
-- needed because GHC's solver does not prove commutativity for abstract
-- @d1@, @d2@. It is only sound because the asserted equality really holds
-- for all naturals; never reuse this pattern for an equality that can fail.
symmetricAddition :: forall d1 d2. (() :- ((d1 + d2) ~ (d2 + d1)))
symmetricAddition = Sub $ unsafeCoerce (Dict :: Dict ())
-- | Constraint needed by 'mulCT'.
type MulCTCtx c m' zp zq = (ToSDCtx c m' zp zq)

-- | Homomorphic multiplication. The first operand must be in 'LSD' form:
--
--   * if both operands are 'MSD', the first is converted with 'toLSD';
--   * if only the first is 'MSD', the operands are swapped (justified at
--     the type level by 'symmetricAddition').
--
-- With an 'LSD' first operand, the result has:
--
--   * the second operand's encoding;
--   * @k1 + k2 + 1@;
--   * @l1 * l2@;
--   * the product polynomial, with every coefficient multiplied by @g@.
mulCT :: forall c m m' zp zq d1 d2 . MulCTCtx c m' zp zq => CT d1 m zp (c m' zq) -> CT d2 m zp (c m' zq) -> CT (d1 + d2) m zp (c m' zq)
mulCT ct1@(CT enc1 k1 l1 c1) ct2@(CT enc2 k2 l2 c2) =
  case (enc1, enc2) of
    (MSD, MSD) -> mulCT (toLSD ct1) ct2
    (LSD, _)   -> CT enc2 (k1 + k2 + 1) (l1 * l2) (mulG <$> c1 * c2)
    (MSD, LSD) -> mulCT ct2 ct1 \\ symmetricAddition @d1 @d2
-- | Constraints needed by 'absorbGFactors'.
type AbsorbGCtx c m' zp zq =
  (Ring (c m' zp), Ring (c m' zq), Cyclotomic (c m' zp), Cyclotomic (c m' zq),
   LiftCyc (c m' zp), Reduce (LiftOf (c m' zp)) (c m' zq))

-- | Resets @k@ to 0, multiplying every coefficient by a representative of
-- @g^-k@. That element is computed in @zp@ by @k@ applications of 'divG'',
-- then lifted in the powerful basis, reduced to @zq@, and CRT-advised.
-- A ciphertext with @k == 0@ is returned unchanged; negative @k@ is an error.
absorbGFactors :: forall c zp zq m m' d. AbsorbGCtx c m' zp zq
  => CT d m zp (c m' zq) -> CT d m zp (c m' zq)
absorbGFactors ct@(CT enc k l r)
  | k < 0 = error "k < 0 in absorbGFactors"
  | k == 0 = ct
  | otherwise =
      let gInvK :: c m' zp
          gInvK = iterate divG' one !! k
          factor = adviseCRT $ reduce $ liftPow gInvK
      in CT enc 0 l (fmap (factor *) r)
-- | Embeds a ciphertext into a larger ring. Factors of @g@ are absorbed first
-- ('absorbGFactors'), then every coefficient is 'embed'ded. The result has
-- @k = 0@ and keeps the encoding and @l@.
embedCT :: (r `Divides` r', s `Divides` s', r `Divides` s, r' `Divides` s',
            ExtensionCyc c zq, AbsorbGCtx c r' zp zq)
  => CT d r zp (c r' zq) -> CT d s zp (c s' zq)
embedCT ct =
  case absorbGFactors ct of
    CT enc _ l poly -> CT enc 0 l (fmap embed poly)
-- | Embeds a secret key's element into a larger ring, keeping its parameter.
embedSK :: (m `Divides` m', ExtensionCyc c z) => SK (c m z) -> SK (c m' z)
embedSK sk =
  case sk of
    SK param keyElt -> SK param (embed keyElt)
-- | Applies 'twace' to a ciphertext. Factors of @g@ are absorbed first
-- ('absorbGFactors'), then every coefficient is traced down. The result has
-- @k = 0@ and keeps the encoding and @l@.
twaceCT :: (r `Divides` r', s' `Divides` r', s ~ (FGCD s' r),
            ExtensionCyc c zq, AbsorbGCtx c r' zp zq)
  => CT d r zp (c r' zq) -> CT d s zp (c s' zq)
twaceCT ct =
  case absorbGFactors ct of
    CT enc _ l poly -> CT enc 0 l (fmap twace poly)
-- | Hint produced by 'tunnelHint' and consumed by 'tunnel'. It holds the
-- extended linear function over @zq@ and one key-switching hint per basis
-- element enumerated in 'tunnelHint', in the same order.
data TunnelHint gad c e r s e' r' s' zp zq =
  THint (Linear c e' r' s' zq) [KSHint gad (c s' zq)]
  deriving (Generic)

deriving instance (Show (c s' zq), Show (KSHint gad (c s' zq)))
  => Show (TunnelHint gad c e r s e' r' s' zp zq)

-- NFData is derived via DeriveAnyClass (Generic-based default method).
deriving instance (NFData (c s' zq))
  => NFData (TunnelHint gad c e r s e' r' s' zp zq)
-- | Constraints needed by 'tunnelHint'.
type TunnelHintCtx c e r s e' r' s' z zp zq' gad =
  (ExtendLinCtx c e r s e' r' s' zp,
   e' ~ (e * (r' / r)),
   Fact r, z ~ LiftOf zp,
   KSHintCtx gad c s' z zq',
   LiftCyc (c s zp), LiftOf (c s zp) ~ c s z,
   ExtensionCyc c z, e' `Divides` r',
   Reduce (c s' z) (c s' zq'),
   Cyclotomic (c r' z),
   Ring (c r' z), Ring (c s' z), Random (c s' zq'), Gadget gad (c s' zq'))

-- | Generates the hint 'tunnel' needs to move ciphertexts from key @sin@
-- (over @r'@) to key @skout@ (over @s'@) while applying the linear map @f@.
--
-- Steps:
--
--   1. Lift @f@ to @z@ in the powerful basis (@liftLin (Just Pow)@).
--   2. Extend it to @f' :: Linear c e' r' s' z@ ('extendLin').
--   3. For each element @b@ of @proxy powBasis (Proxy :: Proxy e')@, make a
--      key-switching hint under @skout@ for @f'(sin * b)@ (see 'ksHint').
--
-- The result stores @f'@ reduced to @zq'@ together with those hints, in
-- basis order.
--
-- NOTE(review): the @powBasis@ elements are presumably the relative powerful
-- basis of @r'@ over @e'@; confirm against 'powBasis'. The whole body is
-- typechecked under the constraints supplied by @lcmDivides \@r \@e'@.
tunnelHint :: forall gad c e r s e' r' s' z zp zq' rnd .
  (MonadRandom rnd, TunnelHintCtx c e r s e' r' s' z zp zq' gad)
  => Linear c e r s zp
  -> SK (c s' z)
  -> SK (c r' z)
  -> rnd (TunnelHint gad c e r s e' r' s' zp zq')
tunnelHint f skout (SK _ sin) =
  (let f' = extendLin $ liftLin (Just Pow) f :: Linear c e' r' s' z
       ps = proxy powBasis (Proxy::Proxy e')
       comps = evalLin f' . (adviseCRT sin *) <$> ps
   in THint (reduce f') <$> CM.mapM (ksHint skout) comps)
  \\ lcmDivides @r @e'
-- | Constraints needed by 'tunnel'.
type TunnelCtx c r s e' r' s' zp zq' gad =
  (Fact r, Fact s, e' `Divides` r', e' `Divides` s', ExtensionCyc c zq',
   ToSDCtx c r' zp zq',
   AbsorbGCtx c r' zp zq',
   SwitchCtx gad c s' zq')

-- | Ring-tunnels a linear ciphertext from @r'@ to @s'@ using a hint from
-- 'tunnelHint'.
--
-- The ciphertext is freed of @g@ factors ('absorbGFactors') and put in
-- 'MSD' form. For @c0 + c1*X@ the result is @f(c0) + sum_i ks_i(b_i)@, where:
--
--   * @f@ is the hint's linear function;
--   * @b_i@ are the powerful-basis coefficients of @c1@ over @e'@
--     ('coeffsPow'), embedded into @s'@;
--   * @ks_i@ are the hint's key switches.
--
-- Previously the coefficient list was matched with an irrefutable @[c0, c1]@
-- pattern, which crashed with an opaque pattern-match failure. Now:
--
--   * a polynomial with fewer than two coefficients (@c1 = 0@) is handled
--     exactly;
--   * a higher-degree ciphertext (allowed by the @d@ in the type) is
--     rejected with an explicit error; relinearize it first with
--     'keySwitchQuadCirc'.
tunnel :: forall gad c e r s e' r' s' zp zq' d.
  (TunnelCtx c r s e' r' s' zp zq' gad)
  => TunnelHint gad c e r s e' r' s' zp zq'
  -> CT d r zp (c r' zq')
  -> CT d s zp (c s' zq')
tunnel (THint f'q hints) ct =
  (let CT MSD 0 s c = toMSD $ absorbGFactors ct
       -- key-switch each powerful-basis component of c1 and sum the results
       switchLinear c1 =
         sum $ zipWith keySwitch hints (embed <$> (coeffsPow c1 :: [c e' zq']))
       c' = case coeffs c of
         []       -> fromCoeffs []
         [c0]     -> P.const (evalLin f'q c0)
         [c0, c1] -> P.const (evalLin f'q c0) + switchLinear c1
         _        -> error "tunnel: expected a linear ciphertext; relinearize with keySwitchQuadCirc first"
   in CT MSD 0 s c')
  \\ lcmDivides @r @e'
-- | Fully evaluates the key's parameter, then its element.
instance (NFData r) => NFData (SK r) where
  rnf (SK param keyElt) = param `deepseq` rnf keyElt
-- | Shows the key as @(SK <parameter as a Rational> <element>)@.
instance Show r => Show (SK r) where
  show (SK v r) = concat ["(SK ", show (toRational v), " ", show r, ")"]
-- | Serializes the key element and its parameter (via 'realToField').
instance (Protoable r, ProtoType r ~ R) => Protoable (SK r) where
  type ProtoType (SK r) = P.SecretKey
  toProto (SK v r) = P.SecretKey (toProto r) (realToField v)
  fromProto (P.SecretKey r v) = fmap (SK v) (fromProto r)
-- | Serializes a polynomial as its coefficient list.
-- NOTE(review): orphan instance; neither 'Protoable' nor 'Polynomial' is
-- defined in this module.
instance (Protoable rq, ProtoType rq ~ RqProduct) => Protoable (Polynomial rq) where
  type ProtoType (Polynomial rq) = P.RqPolynomial
  toProto poly = P.RqPolynomial (toProto (coeffs poly))
  fromProto (P.RqPolynomial cs) = fromCoeffs <$> fromProto cs
-- | Serializes the hint polynomials together with a fingerprint of the
-- gadget type. Deserialization fails (via 'error') if the stored
-- fingerprint does not match the expected gadget.
instance (Typeable gad, Protoable r'q', ProtoType r'q' ~ RqProduct)
  => Protoable (KSHint gad r'q') where
  type ProtoType (KSHint gad r'q') = P.KSHint

  toProto (KSHint cs) = P.KSHint (toProto cs) gadgetTag
    where
      gadgetTag = toProto $ typeRepFingerprint $ typeRep (Proxy::Proxy gad)

  fromProto (P.KSHint poly gadrepr')
    | expectedTag == gadrepr' = KSHint <$> fromProto poly
    | otherwise = error $ "Expected gadget " ++ show (typeRep (Proxy::Proxy gad))
    where
      expectedTag = toProto $ typeRepFingerprint $ typeRep (Proxy::Proxy gad)
-- | Serializes the linear function and key-switching hints together with the
-- indices @e@, @r@, @s@ and the plaintext modulus. Deserialization checks
-- all four against the expected types and fails (via 'error') on mismatch.
instance (Mod zp, Typeable gad,
          Protoable (Linear c e' r' s' zq), Protoable (KSHint gad (c s' zq)),
          Reflects s Int, Reflects r Int, Reflects e Int)
  => Protoable (TunnelHint gad c e r s e' r' s' zp zq) where
  type ProtoType (TunnelHint gad c e r s e' r' s' zp zq) = P.TunnelHint

  toProto (THint linf hints) =
    P.TunnelHint
      (toProto linf)
      (toProto hints)
      (fromIntegral (value @e :: Int))
      (fromIntegral (value @r :: Int))
      (fromIntegral (value @s :: Int))
      (fromIntegral $ modulus @zp)

  fromProto (P.TunnelHint linf hints e r s p)
    | p' == p && e' == e && r' == r && s' == s =
        THint <$> fromProto linf <*> fromProto hints
    | otherwise =
        error $ "Error reading TunnelHint proto data:" ++
          "\nexpected p=" ++ show p' ++ ", got " ++ show p ++
          "\nexpected e=" ++ show e' ++ ", got " ++ show e ++
          "\nexpected r=" ++ show r' ++ ", got " ++ show r ++
          "\nexpected s=" ++ show s' ++ ", got " ++ show s
    where
      e' = fromIntegral (value @e :: Int)
      r' = fromIntegral (value @r :: Int)
      s' = fromIntegral (value @s :: Int)
      p' = fromIntegral (modulus @zp)