module Crypto.Lol.Applications.SymmSHE
(
SK, PT, CT
, genSK, genSKWithVar
, encrypt
, errorTerm, errorTermUnrestricted, decrypt, decryptUnrestricted
, addPublic, mulPublic
, rescaleLinearCT, modSwitchPT
, KSLinearHint, KSQuadCircHint
, ksLinearHint, ksQuadCircHint
, keySwitchLinear, keySwitchQuadCirc
, embedSK, embedCT, twaceCT
, TunnelInfo, tunnelInfo
, tunnelCT
, GenSKCtx, EncryptCtx, ToSDCtx, ErrorTermCtx
, DecryptCtx, DecryptUCtx
, AddPublicCtx, MulPublicCtx, ModSwitchPTCtx
, KeySwitchCtx, KSHintCtx
, GenTunnelInfoCtx, TunnelCtx
, SwitchCtx, LWECtx
) where
import qualified Algebra.Additive as Additive (C)
import qualified Algebra.Ring as Ring (C)
import Crypto.Lol as LP hiding (sin)
import Crypto.Lol.Cyclotomic.UCyc (D, UCyc)
import Crypto.Lol.Reflects
import Crypto.Lol.Types.Proto
import Crypto.Proto.Lol.R (R)
import Crypto.Proto.Lol.RqProduct (RqProduct)
import qualified Crypto.Proto.SHE.KSHint as P
import qualified Crypto.Proto.SHE.RqPolynomial as P
import qualified Crypto.Proto.SHE.SecretKey as P
import qualified Crypto.Proto.SHE.TunnelInfo as P
import Control.Applicative hiding ((*>))
import Control.DeepSeq
import Control.Monad as CM
import Control.Monad.Random hiding (lift)
import Data.Maybe
import Data.Traversable as DT
import Data.Typeable
import MathObj.Polynomial as P
-- | Secret key: a ring element paired with the parameter @v@ that 'genSK'
-- passed to 'errorRounded' when sampling it. The same @v@ is reused as the
-- error parameter by 'encrypt' and 'lweSample'. @v@ is existentially
-- quantified, so only its 'ToRational' and 'NFData' instances are available.
data SK r where
  SK :: (ToRational v, NFData v) => v -> r -> SK r
-- | Plaintext type: a plain synonym for the plaintext ring @rp@.
type PT rp = rp
-- | The two encodings a ciphertext can be in. 'toMSD' and 'toLSD' convert
-- between them using the 'lsdToMSD' / 'msdToLSD' scale factors. (The names
-- presumably stand for most/least significant digit.)
data Encoding = MSD | LSD deriving (Show, Eq)
-- | A ciphertext. Fields, in order:
--
--   * the 'Encoding' the ciphertext is currently in;
--   * @k@, the number of factors of @g@ it carries. 'decrypt' divides its
--     result by @g@ this many times, 'mulGCT' increments it, and
--     'absorbGFactors' folds the factors into the coefficients and resets
--     it to 0;
--   * a scale @l@ in @zp@, by which 'decrypt' multiplies its result;
--   * the ciphertext polynomial over @r'q@, which decryption evaluates at
--     the secret key.
--
-- The phantom index @m@ is the plaintext ring index: 'decrypt' maps its
-- result from @m'@ down to @m@ with 'twace'.
data CT (m :: Factored) zp r'q =
  CT
    !Encoding
    !Int
    !zp
    !(Polynomial r'q)
  deriving (Show)
-- | Deep evaluation of a ciphertext. The 'Encoding' field is strict and its
-- constructors are nullary, so it needs no further forcing.
instance (NFData zp, NFData r'q) => NFData (CT m zp r'q) where
  rnf (CT _ gCount scale poly) = gCount `deepseq` scale `deepseq` rnf poly
-- | Deep evaluation of a secret key: its parameter, then its ring element.
instance (NFData r) => NFData (SK r) where
  rnf (SK var key) = var `deepseq` rnf key
-- | Constraints needed by 'genSK'.
type GenSKCtx t m z v =
  (ToInteger z, Fact m, CElt t z, ToRational v, NFData v)

-- | Generate a secret key whose ring element is sampled with 'errorRounded'
-- using parameter @v@. The parameter is stored in the key.
genSK :: (GenSKCtx t m z v, MonadRandom rnd)
      => v -> rnd (SK (Cyc t m z))
genSK v = SK v <$> errorRounded v
-- | Generate a fresh secret key (in a possibly different ring) using the
-- same sampling parameter that is stored in an existing key.
genSKWithVar :: (ToInteger z, Fact m, CElt t z, MonadRandom rnd)
             => SK a -> rnd (SK (Cyc t m z))
genSKWithVar sk = case sk of
  SK var _ -> genSK var
-- | Constraints needed by 'encrypt'.
type EncryptCtx t m m' z zp zq =
  (Mod zp, Ring zp, Ring zq, Lift zp (ModRep zp), Random zq,
   Reduce z zq, Reduce (LiftOf zp) zq,
   CElt t zq, CElt t zp, CElt t z, CElt t (LiftOf zp),
   m `Divides` m')

-- | Encrypt a plaintext under a secret key, producing a fresh 'LSD'
-- ciphertext with @k = 0@ and scale @1@.
--
-- An error @e@ is sampled with 'errorCoset' (using the key's parameter) for
-- the plaintext embedded into @m'@, and @c1@ is drawn with 'getRandom'. The
-- ciphertext polynomial is @c0 + c1*X@ with @c0 = e - c1*s@. Evaluating it
-- at the key @s@, as 'errorTerm' does, therefore yields @e@.
encrypt :: forall t m m' z zp zq rnd .
  (EncryptCtx t m m' z zp zq, MonadRandom rnd)
  => SK (Cyc t m' z) -> PT (Cyc t m zp) -> rnd (CT m zp (Cyc t m' zq))
encrypt (SK svar s) =
  -- the key reduced mod q is bound outside the lambda, independent of pt
  let sq = adviseCRT $ reduce s
  in \pt -> do
    e <- errorCoset svar (embed pt :: PT (Cyc t m' zp))
    c1 <- getRandom
    return $! CT LSD zero one $ fromCoeffs [reduce e - c1 * sq, c1]
-- | Constraints needed by 'errorTerm'.
type ErrorTermCtx t m' z zp zq =
  (Reduce z zq, Lift' zq, CElt t z, CElt t (LiftOf zq), ToSDCtx t m' zp zq)

-- | The error term of a ciphertext. The ciphertext is converted to 'LSD',
-- its polynomial is evaluated at the secret key reduced into @zq@, and the
-- result is lifted out of @zq@ with @liftCyc Dec@.
errorTerm :: (ErrorTermCtx t m' z zp zq)
          => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> Cyc t m' (LiftOf zq)
errorTerm (SK _ s) = \ct ->
  let CT LSD _ _ poly = toLSD ct
  in liftCyc Dec (evaluate poly sq)
  where
    -- bound outside the lambda so a partially applied errorTerm reuses it
    sq = reduce s
-- | Divide by @g@ with 'divG', failing with a descriptive error when 'divG'
-- returns 'Nothing'. (Previously @fromJust . divG@, whose failure message
-- did not say where it came from.)
divG' :: (Fact m, CElt t r, IntegralDomain r) => Cyc t m r -> Cyc t m r
divG' = fromMaybe (error "divG': divG returned Nothing (input not divisible by g)") . divG
-- | Constraints needed by 'decrypt'.
type DecryptCtx t m m' z zp zq =
  (ErrorTermCtx t m' z zp zq, Reduce (LiftOf zq) zp, IntegralDomain zp,
   m `Divides` m', CElt t zp)

-- | Decrypt a ciphertext. The error term of its 'LSD' form is reduced into
-- @zp@, divided by @g@ @k@ times, mapped from @m'@ to @m@ with 'twace', and
-- multiplied by the scale @l@.
decrypt :: forall t m m' z zp zq . (DecryptCtx t m m' z zp zq)
        => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> PT (Cyc t m zp)
decrypt sk ct = scalarCyc l * twace (iterate divG' e !! k)
  where
    ct'@(CT LSD k l _) = toLSD ct
    e :: Cyc t m' zp
    e = reduce $ errorTerm sk ct'
-- | Constraints needed by 'decryptUnrestricted'.
type DecryptUCtx t m m' z zp zq =
  (Fact m, Fact m', CElt t zp, m `Divides` m',
   Reduce z zq, Lift' zq, CElt t z,
   ToSDCtx t m' zp zq, Reduce (LiftOf zq) zp, IntegralDomain zp)
-- | Like 'errorTerm', but returns the lifted error as a 'UCyc' in the
-- decoding basis: the evaluated polynomial is converted with 'uncycDec' and
-- each coefficient is lifted out of @zq@.
errorTermUnrestricted ::
  (Reduce z zq, Lift' zq, CElt t z, ToSDCtx t m' zp zq)
  => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> UCyc t m' D (LiftOf zq)
errorTermUnrestricted (SK _ s) = \ct ->
  let CT LSD _ _ poly = toLSD ct
  in lift <$> uncycDec (evaluate poly sq)
  where
    sq = reduce s
-- | Same computation as 'decrypt', except that the error term is lifted and
-- reduced coefficient-wise in the decoding basis (via 'uncycDec' and
-- 'cycDec') rather than through 'errorTerm'.
decryptUnrestricted :: (DecryptUCtx t m m' z zp zq)
                    => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> PT (Cyc t m zp)
decryptUnrestricted (SK _ s) = \ct ->
  let CT LSD k l poly = toLSD ct
      e = cycDec $ reduce . lift <$> uncycDec (evaluate poly sq)
  in scalarCyc l * twace (iterate divG' e !! k)
  where
    sq = reduce s
-- | Constraints needed to convert between encodings with 'toLSD' / 'toMSD'.
type ToSDCtx t m' zp zq = (Encode zp zq, Fact m', CElt t zq)

-- | Convert a ciphertext to the 'LSD' / 'MSD' encoding respectively.
-- A ciphertext already in the target encoding is returned unchanged.
-- Otherwise the scale @l@ is multiplied by the @zp@ component of the
-- conversion pair ('msdToLSD' / 'lsdToMSD'), and every polynomial
-- coefficient by its @zq@ component (as a scalar ring element). The @k@
-- field is left untouched.
toLSD, toMSD :: ToSDCtx t m' zp zq
  => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
toMSD =
  -- scale factors are bound outside the lambda (presumably so they are not
  -- rebuilt for every ciphertext)
  let (zpScale, zqScale) = lsdToMSD
      rqScale = scalarCyc zqScale
  in \ct@(CT enc k l c) -> case enc of
       MSD -> ct
       LSD -> CT MSD k (zpScale * l) ((rqScale *) <$> c)
toLSD =
  let (zpScale, zqScale) = msdToLSD
      rqScale = scalarCyc zqScale
  in \ct@(CT enc k l c) -> case enc of
       LSD -> ct
       MSD -> CT LSD k (zpScale * l) ((rqScale *) <$> c)
-- | Rescale a polynomial of degree at most one from @zq@ to @zq'@. The
-- constant coefficient uses 'rescaleDec' and the linear coefficient uses
-- 'rescalePow'. Callers pass 'MSD'-form polynomials (see 'rescaleLinearCT').
-- Calls 'error' when given more than two coefficients.
rescaleLinearMSD :: (RescaleCyc (Cyc t) zq zq', Fact m')
                 => Polynomial (Cyc t m' zq) -> Polynomial (Cyc t m' zq')
rescaleLinearMSD c = case coeffs c of
  []       -> fromCoeffs []
  [c0]     -> fromCoeffs [rescaleDec c0]
  [c0, c1] -> fromCoeffs [rescaleDec c0, rescalePow c1]
  cs       -> error $ "rescaleLinearMSD: list too long (not linear): " ++
                show (length cs)
-- | Rescale a linear ciphertext from @zq@ to @zq'@: convert it to 'MSD', then
-- apply 'rescaleLinearMSD' to its polynomial.
rescaleLinearCT :: (RescaleCyc (Cyc t) zq zq', ToSDCtx t m' zp zq)
                => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq')
rescaleLinearCT ct = CT MSD k l (rescaleLinearMSD c)
  where
    CT MSD k l c = toMSD ct
-- | Constraints needed by 'modSwitchPT'.
type ModSwitchPTCtx t m' zp zp' zq =
  (Lift' zp, Reduce (LiftOf zp) zp', ToSDCtx t m' zp zq)

-- | Change the plaintext modulus from @zp@ to @zp'@. The ciphertext is
-- converted to 'MSD' and its scale is lifted out of @zp@ and reduced into
-- @zp'@. The polynomial itself is unchanged.
modSwitchPT :: (ModSwitchPTCtx t m' zp zp' zq)
            => CT m zp (Cyc t m' zq) -> CT m zp' (Cyc t m' zq)
modSwitchPT ct = CT MSD k (reduce $ lift l) c
  where
    CT MSD k l c = toMSD ct
-- | Constraints needed by 'lweSample'.
type LWECtx t m' z zq =
  (ToInteger z, Reduce z zq, Ring zq, Random zq, Fact m', CElt t z, CElt t zq)

-- | Produce a linear polynomial @c0 + c1*X@ under the secret key @s@:
-- @c1@ is drawn with 'getRandom', and @c0 = c1*(-s) + e@, where @e@ is a
-- fresh error sampled with 'errorRounded' using the key's parameter.
-- Evaluating the polynomial at @s@ therefore gives @e@ (in @zq@).
lweSample :: (LWECtx t m' z zq, MonadRandom rnd)
          => SK (Cyc t m' z) -> rnd (Polynomial (Cyc t m' zq))
lweSample (SK svar s) =
  -- negated key is computed outside the monadic action; adviseCRT
  -- presumably requests the CRT representation for the products below
  let sq = adviseCRT $ negate $ reduce s
  in do
    e <- errorRounded svar
    c1 <- adviseCRT <$> getRandom
    -- asTypeOf pins the error to the key's ring type before reducing
    return $ fromCoeffs [c1 * sq + reduce (e `asTypeOf` s), c1]
-- | Constraints needed by 'ksHint'.
type KSHintCtx gad t m' z zq =
  (LWECtx t m' z zq, Reduce (DecompOf zq) zq, Gadget gad zq,
   NFElt zq, CElt t (DecompOf zq))

-- | Generate a key-switching hint for the ring element @val@ under the key
-- @skout@. @val@ is reduced into @zq@ and gadget-encoded with 'encode'. For
-- every entry of that encoding a fresh 'lweSample' under @skout@ is drawn,
-- and the entry is added as that sample's constant coefficient. Each
-- resulting polynomial therefore evaluates, at @skout@'s key, to its gadget
-- entry plus the sample's error. The hint is fully evaluated ('force')
-- before being returned.
ksHint :: (KSHintCtx gad t m' z zq, MonadRandom rnd)
       => SK (Cyc t m' z) -> Cyc t m' z
       -> rnd (Tagged gad [Polynomial (Cyc t m' zq)])
ksHint skout val = do
  let valq = reduce val
      valgad = encode valq
  -- one LWE sample per gadget entry
  samples <- DT.mapM (\as -> replicateM (length as) (lweSample skout)) valgad
  return $! force $ zipWith (+) <$> (map P.const <$> valgad) <*> samples
-- | Multiply every element of a functor on the left by the ring element @r@.
(*>>) :: (Ring r, Functor f) => r -> f r -> f r
r *>> xs = (r *) <$> xs
-- | Sum of the hint polynomials, each scaled (coefficient-wise) by the
-- corresponding ring element of @xs@ after 'adviseCRT'. Pairs beyond the
-- shorter list are ignored ('zipWith').
knapsack :: (Fact m', CElt t zq, r'q ~ Cyc t m' zq)
         => [Polynomial r'q] -> [r'q] -> Polynomial r'q
knapsack hint xs = sum $ zipWith (\x h -> adviseCRT x *>> h) xs hint
-- | Constraints needed by 'switch'.
type SwitchCtx gad t m' zq =
  (Decompose gad zq, Fact m', CElt t zq, CElt t (DecompOf zq))

-- | Key-switch a single ring element @c@ using a hint. @c@ is split with
-- 'decompose' (the gadget @gad@ is fixed by the hint's tag), each piece is
-- reduced back into @zq@, and the pieces are combined with the hint
-- polynomials by 'knapsack'.
switch :: (SwitchCtx gad t m' zq, r'q ~ Cyc t m' zq)
       => Tagged gad [Polynomial r'q] -> r'q -> Polynomial r'q
switch hint c = untag (knapsack <$> hint <*> pieces)
  where
    pieces = fmap reduce <$> decompose c
-- | Constraints needed by 'keySwitchLinear' and 'keySwitchQuadCirc':
-- rescaling between @zq@ and the hint modulus @zq'@ in both directions,
-- encoding conversion at @zq@, and 'switch' at @zq'@.
type KeySwitchCtx gad t m' zp zq zq' =
  (RescaleCyc (Cyc t) zq' zq, RescaleCyc (Cyc t) zq zq',
   ToSDCtx t m' zp zq, SwitchCtx gad t m' zq')

-- | Hint for 'keySwitchLinear', produced by 'ksLinearHint'.
newtype KSLinearHint gad r'q' = KSLHint (Tagged gad [Polynomial r'q']) deriving (NFData)

-- | Hint for 'keySwitchQuadCirc', produced by 'ksQuadCircHint'. It has the
-- same representation as 'KSLinearHint'.
newtype KSQuadCircHint gad r'q' = KSQHint (Tagged gad [Polynomial r'q']) deriving (NFData)
-- | Hint for switching a linear ciphertext from the second key to @skout@:
-- a 'ksHint' of the second key's ring element, generated under @skout@.
ksLinearHint :: (KSHintCtx gad t m' z zq', MonadRandom rnd)
             => SK (Cyc t m' z)
             -> SK (Cyc t m' z)
             -> rnd (KSLinearHint gad (Cyc t m' zq'))
ksLinearHint skout (SK _ skIn) = fmap KSLHint (ksHint skout skIn)
-- | Switch the key of a linear ciphertext. The ciphertext is converted to
-- 'MSD'. Its linear coefficient is rescaled to the hint modulus, switched,
-- rescaled back with 'rescaleLinearMSD', and added to the constant
-- coefficient. Fails (pattern mismatch) unless the polynomial has exactly
-- two coefficients.
keySwitchLinear :: (KeySwitchCtx gad t m' zp zq zq')
                => KSLinearHint gad (Cyc t m' zq') -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
keySwitchLinear (KSLHint hint) ct = CT MSD k l (P.const c0 + switched)
  where
    CT MSD k l c = toMSD ct
    [c0, c1] = coeffs c
    switched = rescaleLinearMSD $ switch hint $ rescalePow c1
-- | Hint for 'keySwitchQuadCirc': a 'ksHint' of the square of the key's
-- ring element, generated under the same key.
ksQuadCircHint :: (KSHintCtx gad t m' z zq', MonadRandom rnd)
               => SK (Cyc t m' z)
               -> rnd (KSQuadCircHint gad (Cyc t m' zq'))
ksQuadCircHint sk@(SK _ s) = KSQHint <$> ksHint sk sSquared
  where
    sSquared = s * s
-- | Switch a quadratic ciphertext back to a linear one under the same key.
-- The ciphertext is converted to 'MSD'. Its quadratic coefficient is
-- rescaled to the hint modulus, switched, rescaled back, and added to the
-- linear part. Fails (pattern mismatch) unless the polynomial has exactly
-- three coefficients.
keySwitchQuadCirc :: (KeySwitchCtx gad t m' zp zq zq')
                  => KSQuadCircHint gad (Cyc t m' zq') -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
keySwitchQuadCirc (KSQHint hint) ct = CT MSD k l (P.fromCoeffs [c0, c1] + switched)
  where
    CT MSD k l c = toMSD ct
    [c0, c1, c2] = coeffs c
    switched = rescaleLinearMSD $ switch hint $ rescalePow c2
-- | Constraints needed by 'addPublic'.
type AddPublicCtx t m m' zp zq = (Lift' zp, Reduce (LiftOf zp) zq,
  CElt t zp, CElt t (LiftOf zp), ToSDCtx t m' zp zq, m `Divides` m')

-- | Homomorphically add the public plaintext @b@ to a ciphertext. The
-- ciphertext is converted to 'LSD'. @b@ is multiplied by @g@ @k@ times (to
-- match the ciphertext's @g@-factor count) and by @l^(-1)@ (to match its
-- scale). It is then lifted in the powerful basis, reduced into @zq@,
-- embedded into the ciphertext ring, and added to the constant coefficient.
addPublic :: forall t m m' zp zq . (AddPublicCtx t m m' zp zq)
          => Cyc t m zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
addPublic b ct = let CT LSD k l c = toLSD ct in
  let linv = scalarCyc $ recip l
      b' :: Cyc t m zq = reduce $ liftPow $ linv * (iterate mulG b !! k)
  in CT LSD k l $ c + P.const (embed b')
-- | Multiply every ciphertext coefficient by the plaintext scalar @a@,
-- lifted out of @zp@ and reduced into @zq@. Encoding, @k@ and scale are kept.
mulScalar :: (Lift' zp, Reduce (LiftOf zp) zq, Fact m', CElt t zq)
          => zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
mulScalar a (CT enc k l c) = CT enc k l (fmap (scalar *) c)
  where
    scalar = scalarCyc $ reduce $ lift a
-- | Constraints needed by 'mulPublic'.
type MulPublicCtx t m m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, Fact m', CElt t zq, m `Divides` m',
   CElt t zp, CElt t (LiftOf zp))

-- | Homomorphically multiply a ciphertext by the public plaintext @a@. @a@ is
-- lifted in the powerful basis, reduced into @zq@, and embedded into the
-- ciphertext ring, and every coefficient is multiplied by it. Encoding,
-- @k@ and scale are kept.
mulPublic :: forall t m m' zp zq . (MulPublicCtx t m m' zp zq)
          => Cyc t m zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
mulPublic a (CT enc k l c) = CT enc k l (fmap (embedded *) c)
  where
    aq :: Cyc t m zq
    aq = reduce $ liftPow a
    embedded = embed aq
-- | Multiply every ciphertext coefficient by @g@, recording this by
-- incrementing the @k@ counter.
mulGCT :: (Fact m', CElt t zq)
       => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
mulGCT (CT enc k l c) = CT enc (1 + k) l (fmap mulG c)
-- | Homomorphic addition of ciphertexts under the same key. Before the
-- coefficients are added, the operands are brought into agreement:
--
--   * if the scales differ, the first operand is multiplied by @l1/l2@
--     (via 'mulScalar') and relabelled with scale @l2@;
--   * if the @g@-factor counts differ, the operand with fewer factors is
--     multiplied by @g@ (via 'mulGCT') as many times as the difference;
--   * if the encodings differ, the 'LSD' operand is converted to 'MSD'.
instance (Lift' zp, Reduce (LiftOf zp) zq, Fact m', CElt t zq,
          Eq zp, m `Divides` m', ToSDCtx t m' zp zq)
  => Additive.C (CT m zp (Cyc t m' zq)) where
  zero = CT LSD 0 one zero

  ct1@(CT enc1 k1 l1 c1) + ct2@(CT enc2 k2 l2 c2)
    | l1 /= l2 =
        let (CT enc' k' _ c') = mulScalar (l1*(recip l2)) ct1
        in (CT enc' k' l2 c') + ct2
    -- raise the smaller g-factor count to match the larger one
    | k1 < k2 = iterate mulGCT ct1 !! (k2 - k1) + ct2
    | k1 > k2 = ct1 + iterate mulGCT ct2 !! (k1 - k2)
    | enc1 == LSD && enc2 == MSD = toMSD ct1 + ct2
    | enc1 == MSD && enc2 == LSD = ct1 + toMSD ct2
    | otherwise = CT enc1 k1 l1 $ c1 + c2

  negate (CT enc k l c) = CT enc k l $ negate <$> c
-- | Homomorphic multiplication of ciphertexts. The equations are tried in
-- order: two 'MSD' operands have the first converted to 'LSD'; an 'LSD'
-- left operand is multiplied directly; an 'MSD' left operand with an 'LSD'
-- right operand is handled by swapping the operands.
instance (ToSDCtx t m' zp zq, Additive (CT m zp (Cyc t m' zq)))
  => Ring.C (CT m zp (Cyc t m' zq)) where
  one = CT LSD 0 one one
  ct1@(CT MSD _ _ _) * ct2@(CT MSD _ _ _) = toLSD ct1 * ct2
  -- the product polynomial is multiplied by g once more, which the
  -- "+1" in the new g-factor count records; scales multiply and the
  -- result takes the right operand's encoding
  (CT LSD k1 l1 c1) * (CT d2 k2 l2 c2) =
    CT d2 (k1+k2+1) (l1*l2) (mulG <$> c1 * c2)
  ct1 * ct2 = ct2 * ct1
-- | Constraints needed by 'absorbGFactors'.
type AbsorbGCtx t m' zp zq =
  (Lift' zp, IntegralDomain zp, Reduce (LiftOf zp) zq, Ring zq,
   Fact m', CElt t (LiftOf zp), CElt t zp, CElt t zq)

-- | Fold a ciphertext's @k@ factors of @g@ into its coefficients, yielding an
-- equivalent ciphertext with @k = 0@. The coefficients are multiplied by
-- @g^(-k)@, computed in @zp@ by dividing one by @g@ @k@ times, then lifted
-- in the powerful basis and reduced into @zq@. Calls 'error' when @k < 0@.
absorbGFactors :: forall t zp zq m m' . (AbsorbGCtx t m' zp zq)
               => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
absorbGFactors ct@(CT enc k l c)
  | k < 0 = error "k < 0 in absorbGFactors"
  | k == 0 = ct
  | otherwise =
      let d :: Cyc t m' zp
          d = iterate divG' one !! k
          rep = adviseCRT $ reduce $ liftPow d
      in CT enc 0 l $ fmap (rep *) c
-- | Embed a ciphertext into a larger ring by embedding every coefficient of
-- its polynomial, keeping encoding and scale. Requires @k = 0@; otherwise it
-- calls 'error'. (The message previously misspelled the function to call as
-- @aborbGFactors@.)
embedCT :: (CElt t zq,
            r `Divides` r', s `Divides` s', r `Divides` s, r' `Divides` s')
        => CT r zp (Cyc t r' zq) -> CT s zp (Cyc t s' zq)
embedCT (CT d 0 l c) = CT d 0 l (embed <$> c)
embedCT _ = error "embedCT requires 0 factors of g; call absorbGFactors first"
-- | Embed a secret key's ring element into a larger ring, keeping its
-- sampling parameter.
embedSK :: (m `Divides` m') => SK (Cyc t m z) -> SK (Cyc t m' z)
embedSK sk = case sk of
  SK var s -> SK var (embed s)
-- | Apply 'twace' to every coefficient of a ciphertext's polynomial, keeping
-- encoding and scale. Requires @k = 0@; otherwise it calls 'error'.
twaceCT :: (CElt t zq, r `Divides` r', s' `Divides` r',
            s ~ (FGCD s' r))
        => CT r zp (Cyc t r' zq) -> CT s zp (Cyc t s' zq)
twaceCT ct = case ct of
  CT d 0 l c -> CT d 0 l (fmap twace c)
  _ -> error "twaceCT requires 0 factors of g; call absorbGFactors first"
-- | Precomputed data for 'tunnelCT', produced by 'tunnelInfo'. It holds the
-- linear function extended to the @e', r', s'@ indices and reduced into
-- @zq@, plus one gadget-tagged key-switching hint per basis element that
-- 'tunnelInfo' iterates over. The indices @e r s@ and the plaintext type
-- @zp@ appear only in the type, not in the fields.
data TunnelInfo gad t (e :: Factored) (r :: Factored) (s :: Factored) e' r' s' zp zq =
  TInfo (Linear t zq e' r' s') [Tagged gad [Polynomial (Cyc t s' zq)]]
-- | Deep evaluation of tunneling data: the linear function, then the hints.
instance (NFData (Linear t zq e' r' s'), NFData (Cyc t s' zq))
  => NFData (TunnelInfo gad t e r s e' r' s' zp zq) where
  rnf (TInfo lin hs) = lin `deepseq` rnf hs
-- | Constraints needed by 'tunnelInfo'. Note that @e'@ is forced to equal
-- @e * (r' / r)@.
type GenTunnelInfoCtx t e r s e' r' s' z zp zq gad =
  (ExtendLinIdx e r s e' r' s',
   e' ~ (e * (r' / r)),
   KSHintCtx gad t r' z zq,
   Lift zp z, CElt t zp,
   CElt t z, e' `Divides` r')
-- | Generate 'TunnelInfo' for tunneling ciphertexts along the linear
-- function @f@. The second key (over @r'@) is the input key and @skout@
-- (over @s'@) is the output key. @f@ is lifted to @z@ and extended to the
-- @e', r', s'@ indices. For each element @b@ of @proxy powBasis (Proxy e')@
-- (presumably the relative powerful basis over @e'@; confirm), the extended
-- function is applied to @sin * b@ and a 'ksHint' under @skout@ is generated
-- for the result.
tunnelInfo :: forall gad t e r s e' r' s' z zp zq rnd .
  (MonadRandom rnd, GenTunnelInfoCtx t e r s e' r' s' z zp zq gad)
  => Linear t zp e r s
  -> SK (Cyc t s' z)
  -> SK (Cyc t r' z)
  -> rnd (TunnelInfo gad t e r s e' r' s' zp zq)
tunnelInfo f skout (SK _ sin) =
  (let f' = extendLin $ lift f :: Linear t z e' r' s'
       f'q = reduce f' :: Linear t zq e' r' s'
       ps = proxy powBasis (Proxy::Proxy e')
       comps = (evalLin f' . (adviseCRT sin *)) <$> ps
   in TInfo f'q <$> CM.mapM (ksHint skout) comps)
  -- supplies the divisibility constraint relating r and e' that the body
  -- needs (via the constraint-discharging operator \\)
  \\ lcmDivides (Proxy::Proxy r) (Proxy::Proxy e')
-- | Constraints needed by 'tunnelCT'.
type TunnelCtx t r s e' r' s' zp zq gad =
  (Fact r, Fact s, e' `Divides` r', e' `Divides` s', CElt t zp,
   ToSDCtx t r' zp zq,
   AbsorbGCtx t r' zp zq,
   SwitchCtx gad t s' zq)
-- | Tunnel a linear ciphertext from ring @r'@ to ring @s'@ using 'TunnelInfo'.
-- The ciphertext's @g@ factors are absorbed and it is converted to 'MSD'.
-- It must then be linear (@[c0,c1]@, else a pattern-match failure). @c0@
-- is mapped directly by the stored linear function. @c1@ is split into
-- its 'coeffsPow' coefficients over @e'@, and each is embedded and
-- key-switched with the corresponding hint; the results are summed.
-- The scale is kept and @k@ is 0.
tunnelCT :: forall gad t e r s e' r' s' zp zq .
  (TunnelCtx t r s e' r' s' zp zq gad, e ~ FGCD r s)
  => TunnelInfo gad t e r s e' r' s' zp zq
  -> CT r zp (Cyc t r' zq)
  -> CT s zp (Cyc t s' zq)
tunnelCT (TInfo f'q hints) ct =
  -- note: the term variable s below is the ciphertext's scale, not the
  -- type index s
  (let CT MSD 0 s c = toMSD $ absorbGFactors ct
       [c0,c1] = coeffs c
       c0' = evalLin f'q c0
       c1s = coeffsPow c1 :: [Cyc t e' zq]
       c1s' = zipWith switch hints (embed <$> c1s)
       c1' = sum c1s'
   in CT MSD 0 s $ P.const c0' + c1')
  \\ lcmDivides (Proxy::Proxy r) (Proxy::Proxy e')
-- | Serialize a secret key as its ring element plus its sampling parameter
-- (converted with 'realToField').
instance (Protoable r, ProtoType r ~ R) => Protoable (SK r) where
  type ProtoType (SK r) = P.SecretKey
  toProto (SK var key) = P.SecretKey (toProto key) (realToField var)
  fromProto (P.SecretKey key var) = fmap (SK var) (fromProto key)
-- | Serialize a polynomial as the list of its coefficients.
instance (Protoable rq, ProtoType rq ~ RqProduct) => Protoable (Polynomial rq) where
  type ProtoType (Polynomial rq) = P.RqPolynomial
  toProto poly = P.RqPolynomial $ toProto $ coeffs poly
  fromProto (P.RqPolynomial cs) = fmap fromCoeffs (fromProto cs)
-- | Serialize a linear key-switching hint together with a fingerprint of
-- the gadget type. Deserialization calls 'error' if the stored fingerprint
-- does not match the expected gadget.
instance (Typeable gad, Protoable r'q', ProtoType r'q' ~ RqProduct)
  => Protoable (KSLinearHint gad r'q') where
  type ProtoType (KSLinearHint gad r'q') = P.KSHint

  toProto (KSLHint cs) = P.KSHint polys fingerprint
    where
      polys = toProto $ proxy cs (Proxy::Proxy gad)
      fingerprint = toProto $ typeRepFingerprint $ typeRep (Proxy::Proxy gad)

  fromProto (P.KSHint poly gadrepr')
    | expected == gadrepr' = KSLHint . tag <$> fromProto poly
    | otherwise = error $ "Expected gadget " ++ show (typeRep (Proxy::Proxy gad))
    where
      expected = toProto $ typeRepFingerprint $ typeRep (Proxy::Proxy gad)
-- | Quadratic hints share the representation (and serialization) of linear
-- hints.
instance (Typeable gad, Protoable r'q', ProtoType r'q' ~ RqProduct)
  => Protoable (KSQuadCircHint gad r'q') where
  type ProtoType (KSQuadCircHint gad r'q') = P.KSHint
  toProto (KSQHint hint) = toProto (KSLHint hint)
  fromProto proto = (\(KSLHint hint) -> KSQHint hint) <$> fromProto proto
-- | Serialize tunneling data: the linear function, the hints (each wrapped
-- as a 'KSLinearHint'), the type-level indices @e@, @r@, @s@, and the
-- plaintext modulus. Deserialization compares the stored indices and modulus
-- with the expected type-level values and calls 'error' on any mismatch.
instance (Mod zp, Typeable gad,
          Protoable (Linear t zq e' r' s'),
          Protoable (KSLinearHint gad (Cyc t s' zq)), Reflects s Int, Reflects r Int, Reflects e Int)
  => Protoable (TunnelInfo gad t e r s e' r' s' zp zq) where
  type ProtoType (TunnelInfo gad t e r s e' r' s' zp zq) = P.TunnelInfo
  toProto (TInfo linf hints) =
    P.TunnelInfo
      (toProto linf)
      (toProto $ KSLHint <$> hints)
      (fromIntegral (proxy value (Proxy::Proxy e) :: Int))
      (fromIntegral (proxy value (Proxy::Proxy r) :: Int))
      (fromIntegral (proxy value (Proxy::Proxy s) :: Int))
      (fromIntegral $ proxy modulus (Proxy::Proxy zp))
  -- the primed names below hold the EXPECTED values from the type; the
  -- unprimed e r s p are the values read from the proto
  fromProto (P.TunnelInfo linf hints e r s p) =
    let e' = fromIntegral $ (proxy value (Proxy::Proxy e) :: Int)
        r' = fromIntegral $ (proxy value (Proxy::Proxy r) :: Int)
        s' = fromIntegral $ (proxy value (Proxy::Proxy s) :: Int)
        p' = fromIntegral $ proxy modulus (Proxy::Proxy zp)
    in if p' == p && e' == e && r' == r && s' == s
       then do
         linf' <- fromProto linf
         hs <- (map (\(KSLHint x) -> x)) <$> fromProto hints
         return $ TInfo linf' hs
       else error $ "Error reading TunnelInfo proto data:" ++
         "\nexpected p=" ++ show p' ++ ", got " ++ show p ++
         "\nexpected e=" ++ show e' ++ ", got " ++ show e ++
         "\nexpected r=" ++ show r' ++ ", got " ++ show r ++
         "\nexpected s=" ++ show s' ++ ", got " ++ show s