module Crypto.Lol.Applications.SymmSHE
(
SK, PT, CT
, genSK
, encrypt
, errorTerm, errorTermUnrestricted, decrypt, decryptUnrestricted
, addScalar, addPublic, mulPublic
, rescaleLinearCT, modSwitchPT
, keySwitchLinear, keySwitchQuadCirc
, embedSK, embedCT, twaceCT
, tunnelCT
, GenSKCtx, EncryptCtx, ToSDCtx, ErrorTermCtx
, DecryptCtx, DecryptUCtx
, AddScalarCtx, AddPublicCtx, MulPublicCtx, ModSwitchPTCtx
, SwitchCtx, KeySwitchCtx, KSHintCtx
, TunnelCtx
) where
import qualified Algebra.Additive as Additive (C)
import qualified Algebra.Ring as Ring (C)
import Crypto.Lol.Cyclotomic.Cyc
import Crypto.Lol.Cyclotomic.UCyc (UCyc, D)
import Crypto.Lol.Cyclotomic.Linear
import Crypto.Lol.Gadget
import Crypto.Lol.LatticePrelude as LP hiding (sin)
import Control.Applicative hiding ((*>))
import Control.DeepSeq
import Control.Monad as CM
import Control.Monad.Random
import Data.Maybe
import Data.Traversable as DT
import MathObj.Polynomial as P
-- | Secret key.  Pairs an error parameter @v@ with the key value.  The
-- parameter is the one handed to 'errorRounded' when the key is generated
-- ('genSK') and when fresh errors are sampled under this key ('encrypt',
-- 'lweSample').
data SK r where
  SK :: (ToRational v, NFData v) => v -> r -> SK r

-- | Plaintext type: a transparent alias for the plaintext ring.
type PT rp = rp

-- | How the message is encoded inside a ciphertext.  'toLSD' and 'toMSD'
-- convert between the two forms by scaling.  NOTE(review): the names
-- presumably stand for "least/most significant digit" encodings — confirm.
data Encoding = MSD | LSD deriving (Show, Eq)
-- | Symmetric-key ciphertext.  The fields, all strict, are:
--
--   * the 'Encoding' of the message ('LSD' or 'MSD');
--   * @k@, the number of factors of @g@ carried by the ciphertext
--     ('mulGCT' increments it while multiplying by @g@; decryption divides
--     by @g@ @k@ times);
--   * @l@, a @zp@ scale that decryption multiplies into its result;
--   * the ciphertext polynomial, which decryption evaluates at the key.
--
-- The index @m@ is phantom: it records the plaintext cyclotomic index.
data CT (m :: Factored) zp r'q =
  CT
    !Encoding
    !Int
    !zp
    !(Polynomial r'q)
  deriving (Show)
-- | Deep-forces the g-factor count, the scale and the polynomial.  The
-- 'Encoding' field is strict and its constructors are nullary, so it is
-- already fully evaluated.
instance (NFData zp, NFData r'q) => NFData (CT m zp r'q) where
  rnf (CT _ numGs scale poly) = numGs `deepseq` scale `deepseq` rnf poly

-- | Deep-forces both the error parameter and the key value.
instance (NFData r) => NFData (SK r) where
  rnf (SK param key) = param `deepseq` rnf key
-- | Constraints needed by 'genSK'.
type GenSKCtx t m z v =
  (ToInteger z, Fact m, CElt t z, ToRational v, NFData v)

-- | Generates a secret key.  The key value is drawn with 'errorRounded'
-- at parameter @v@, and @v@ itself is stored alongside it.
genSK :: (GenSKCtx t m z v, MonadRandom rnd)
      => v -> rnd (SK (Cyc t m z))
genSK v = do
  key <- errorRounded v
  return (SK v key)
-- | Constraints needed by 'encrypt'.
type EncryptCtx t m m' z zp zq =
  (Mod zp, Ring zp, Ring zq, Lift zp (ModRep zp), Random zq,
   Reduce z zq, Reduce (LiftOf zp) zq,
   CElt t zq, CElt t zp, CElt t z, CElt t (LiftOf zp),
   m `Divides` m')

-- | Encrypts a plaintext over index @m@ under a secret key over @m'@.
--
-- The result is a linear ciphertext @c0 + c1*X@ in 'LSD' form, with no
-- factors of @g@ (@k = 0@) and scale 'one'.
--
--   * An error @e@ is sampled with 'errorCoset' from the coset of the
--     embedded plaintext, using the key's error parameter.
--   * @c1@ is drawn with 'getRandom'.
--   * Setting @c0 = e - c1*s@ makes the ciphertext evaluate to @e@ at the
--     secret key @s@, which is what 'errorTerm' recovers.
--
-- The reduced key is bound outside the per-plaintext lambda, so a partial
-- application @encrypt sk@ reuses it across encryptions.
encrypt :: forall t m m' z zp zq rnd . (EncryptCtx t m m' z zp zq, MonadRandom rnd)
        => SK (Cyc t m' z) -> PT (Cyc t m zp) -> rnd (CT m zp (Cyc t m' zq))
encrypt (SK svar s) =
  let sq = adviseCRT $ reduce s
  in \pt -> do
    e <- errorCoset svar (embed pt :: PT (Cyc t m' zp))
    c1 <- getRandom
    -- c0 = e - c1*s, so that evaluating the ciphertext at s yields e.
    return $! CT LSD zero one $ fromCoeffs [reduce e - c1 * sq, c1]
-- | Constraints needed by 'errorTerm'.
type ErrorTermCtx t m' z zp zq =
  (Reduce z zq, Lift' zq, CElt t z, CElt t (LiftOf zq), ToSDCtx t m' zp zq)

-- | Extracts the error term of a ciphertext.  The ciphertext is converted
-- to 'LSD' form, its polynomial is evaluated at the reduced secret key, and
-- the result is lifted via the 'Dec' basis.
--
-- The reduced key is computed once per partial application @errorTerm sk@.
errorTerm :: (ErrorTermCtx t m' z zp zq)
          => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> Cyc t m' (LiftOf zq)
errorTerm (SK _ s) = extract
  where
    keyq = reduce s
    extract ct =
      let CT LSD _ _ poly = toLSD ct
      in liftCyc Dec (evaluate poly keyq)
-- | Divides by @g@, assuming the division succeeds.
--
-- When 'divG' returns 'Nothing', this fails with a message naming this
-- helper, instead of the context-free @fromJust@ failure.
divG' :: (Fact m, CElt t r) => Cyc t m r -> Cyc t m r
divG' = fromMaybe (error "divG': division by g failed (divG returned Nothing)") . divG
-- | Constraints needed by 'decrypt'.
type DecryptCtx t m m' z zp zq =
  (ErrorTermCtx t m' z zp zq, Reduce (LiftOf zq) zp,
   m `Divides` m', CElt t zp)

-- | Decrypts a ciphertext, in four steps:
--
--   1. Reduce its error term to @zp@.
--   2. Divide out its @k@ factors of @g@.
--   3. Apply 'twace' down to the plaintext index @m@.
--   4. Multiply by the ciphertext's scale @l@.
decrypt :: forall t m m' z zp zq . (DecryptCtx t m m' z zp zq)
        => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> PT (Cyc t m zp)
decrypt sk ct = scalarCyc scale * twace (iterate divG' errP !! numGs)
  where
    lsdCT@(CT LSD numGs scale _) = toLSD ct
    errP :: Cyc t m' zp
    errP = reduce (errorTerm sk lsdCT)
-- | Constraints needed by 'decryptUnrestricted'.  Unlike 'DecryptCtx', this
-- does not list @CElt t (LiftOf zq)@.
type DecryptUCtx t m m' z zp zq =
  ( Fact m, Fact m', m `Divides` m'
  , CElt t z, CElt t zp
  , Reduce z zq, Reduce (LiftOf zq) zp, Lift' zq
  , ToSDCtx t m' zp zq )
-- | Like 'errorTerm', but returns the lifted error term as a 'UCyc' in the
-- 'D' (decoding) representation.  The lift is applied coefficient-wise with
-- 'fmap' rather than via 'liftCyc'.
errorTermUnrestricted ::
  (Reduce z zq, Lift' zq, CElt t z, ToSDCtx t m' zp zq)
  => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> UCyc t m' D (LiftOf zq)
errorTermUnrestricted (SK _ s) = extract
  where
    keyq = reduce s
    extract ct =
      let CT LSD _ _ poly = toLSD ct
      in lift <$> uncycDec (evaluate poly keyq)
-- | Like 'decrypt', but goes through 'UCyc' in the decoding representation.
-- The evaluated polynomial is lifted and reduced to @zp@ coefficient-wise,
-- so only 'DecryptUCtx' is required.
--
-- The reduced key is computed once per partial application.
decryptUnrestricted :: (DecryptUCtx t m m' z zp zq)
                    => SK (Cyc t m' z) -> CT m zp (Cyc t m' zq) -> PT (Cyc t m zp)
decryptUnrestricted (SK _ s) = go
  where
    keyq = reduce s
    go ct =
      let CT LSD numGs scale poly = toLSD ct
          errP = cycDec (reduce . lift <$> uncycDec (evaluate poly keyq))
      in scalarCyc scale * twace (iterate divG' errP !! numGs)
-- | Constraints needed to convert between 'LSD' and 'MSD' encodings.
type ToSDCtx t m' zp zq = (Encode zp zq, Fact m', CElt t zq)

-- | Convert a ciphertext to 'LSD' form ('toLSD') or 'MSD' form ('toMSD').
--
-- A ciphertext already in the target form is returned unchanged.
-- Otherwise its scale is multiplied by the @zp@ component, and each
-- polynomial coefficient by the @zq@ component, of 'msdToLSD' (for
-- 'toLSD') or 'lsdToMSD' (for 'toMSD').  The scaling constants are bound
-- outside the per-ciphertext function.
toLSD, toMSD :: ToSDCtx t m' zp zq
             => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
toMSD =
  let (pFactor, qFactor) = lsdToMSD
      cycFactor = scalarCyc qFactor
      convert ct@(CT enc k l c)
        | enc == MSD = ct
        | otherwise  = CT MSD k (pFactor * l) ((cycFactor *) <$> c)
  in convert
toLSD =
  let (pFactor, qFactor) = msdToLSD
      cycFactor = scalarCyc qFactor
      convert ct@(CT enc k l c)
        | enc == LSD = ct
        | otherwise  = CT LSD k (pFactor * l) ((cycFactor *) <$> c)
  in convert
-- | Rescales each coefficient of a polynomial of degree at most one from
-- @zq@ to @zq'@.  The constant coefficient is rescaled in the 'Dec' basis
-- and the linear one in the 'Pow' basis.  Polynomials with more than two
-- coefficients are rejected with an error.
rescaleLinearMSD :: (RescaleCyc (Cyc t) zq zq', Fact m')
                 => Polynomial (Cyc t m' zq) -> Polynomial (Cyc t m' zq')
rescaleLinearMSD c =
  case coeffs c of
    []       -> fromCoeffs []
    [c0]     -> fromCoeffs [rescaleCyc Dec c0]
    [c0, c1] -> fromCoeffs [rescaleCyc Dec c0, rescaleCyc Pow c1]
    cs       -> error $ "rescaleLinearMSD: list too long (not linear): " ++
                  show (length cs)
-- | Rescales a linear ciphertext's modulus type from @zq@ to @zq'@.  The
-- ciphertext is first put in 'MSD' form, and its @k@ and scale are kept.
rescaleLinearCT :: (RescaleCyc (Cyc t) zq zq', ToSDCtx t m' zp zq)
                => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq')
rescaleLinearCT ct = CT MSD k l (rescaleLinearMSD c)
  where
    CT MSD k l c = toMSD ct
-- | Constraints needed by 'modSwitchPT'.
type ModSwitchPTCtx t m' zp zp' zq =
  (Lift' zp, Reduce (LiftOf zp) zp', ToSDCtx t m' zp zq)

-- | Changes a ciphertext's plaintext modulus type from @zp@ to @zp'@.  The
-- ciphertext is put in 'MSD' form, and its scale is lifted and reduced into
-- @zp'@.  Its @k@ and polynomial are kept.
modSwitchPT :: (ModSwitchPTCtx t m' zp zp' zq)
            => CT m zp (Cyc t m' zq) -> CT m zp' (Cyc t m' zq)
modSwitchPT ct = CT MSD k (reduce (lift l)) c
  where
    CT MSD k l c = toMSD ct
-- | Constraints needed by 'lweSample'.
type LWECtx t m' z zq =
  (ToInteger z, Reduce z zq, Ring zq, Random zq, Fact m', CElt t z, CElt t zq)

-- | Draws a linear polynomial @c0 + c1*X@ under the given key.
--
--   * The error @e@ is drawn with 'errorRounded' at the key's parameter.
--   * @c1@ is drawn with 'getRandom'.
--   * @c0 = c1*(-s) + e@, so evaluating the polynomial at @s@ gives @e@.
lweSample :: (LWECtx t m' z zq, MonadRandom rnd)
          => SK (Cyc t m' z) -> rnd (Polynomial (Cyc t m' zq))
lweSample (SK svar s) = do
  err <- errorRounded svar
  a <- adviseCRT <$> getRandom
  return $ fromCoeffs [a * negKeyq + reduce (err `asTypeOf` s), a]
  where
    negKeyq = adviseCRT $ negate $ reduce s
-- | Constraints needed by 'ksHint'.
type KSHintCtx gad t m' z zq =
  (LWECtx t m' z zq, Reduce (DecompOf zq) zq, Gadget gad zq,
   NFElt zq, CElt t (DecompOf zq))

-- | Generates a key-switching hint for @val@ under @skout@.
--
-- @val@ is reduced and gadget-encoded with 'encode'.  Each resulting entry
-- is added, as a constant polynomial, to a fresh 'lweSample' drawn under
-- @skout@.  The hint is fully forced before it is returned.
ksHint :: (KSHintCtx gad t m' z zq, MonadRandom rnd)
       => SK (Cyc t m' z) -> Cyc t m' z
       -> rnd (Tagged gad [Polynomial (Cyc t m' zq)])
ksHint skout val = do
  let encoded = encode (reduce val)
      consts = map P.const <$> encoded
  noise <- DT.mapM (\entries -> replicateM (length entries) (lweSample skout)) encoded
  return $! force (zipWith (+) <$> consts <*> noise)
-- | Left-multiplies every element inside a functor by a ring element.
(*>>) :: (Ring r, Functor f) => r -> f r -> f r
r *>> xs = (r *) <$> xs
-- | Sums the hint polynomials, each scaled by the matching ring element.
-- The element is advised into CRT form before scaling.  As with 'zipWith',
-- surplus entries in either list are ignored.
knapsack :: (Fact m', CElt t zq, r'q ~ Cyc t m' zq)
         => [Polynomial r'q] -> [r'q] -> Polynomial r'q
knapsack hint xs = sum (zipWith (\x h -> adviseCRT x *>> h) xs hint)
-- | Constraints needed by 'switch'.
type SwitchCtx gad t m' zq =
  (Decompose gad zq, Fact m', CElt t zq, CElt t (DecompOf zq))

-- | Key-switches a ring element with a hint.  The element is decomposed
-- with respect to the gadget, each digit is reduced, and the result is
-- the 'knapsack' of the digits against the hint polynomials.
switch :: (SwitchCtx gad t m' zq, r'q ~ Cyc t m' zq)
       => Tagged gad [Polynomial r'q] -> r'q -> Polynomial r'q
switch hint c =
  let digits = fmap reduce <$> decompose c
  in untag (knapsack <$> hint <*> digits)
-- | Constraints needed by the key-switching functions.
type KeySwitchCtx gad t m' zp zq zq' =
  (RescaleCyc (Cyc t) zq' zq, RescaleCyc (Cyc t) zq zq',
   ToSDCtx t m' zp zq, SwitchCtx gad t m' zq')

-- | Produces a function that switches a linear ciphertext from the input
-- key @sin@ (second argument) to the output key @skout@ (first argument).
--
-- The hint for @sin@ under @skout@ is generated once with 'ksHint'.  Each
-- call then:
--
--   1. puts the ciphertext in 'MSD' form;
--   2. rescales its linear coefficient to @zq'@;
--   3. key-switches that coefficient with the hint;
--   4. rescales the switched polynomial back to @zq@ and adds it to the
--      unchanged constant term.
keySwitchLinear :: forall gad t m' zp zq zq' z rnd m .
  (KeySwitchCtx gad t m' zp zq zq', KSHintCtx gad t m' z zq', MonadRandom rnd)
  => SK (Cyc t m' z)
  -> SK (Cyc t m' z)
  -> TaggedT (gad, zq') rnd (CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq))
keySwitchLinear skout (SK _ sin) = tagT $ do
  hint :: Tagged gad [Polynomial (Cyc t m' zq')] <- ksHint skout sin
  let switchCT ct =
        let CT MSD k l c = toMSD ct
            [c0, c1] = coeffs c
            c1Up = rescaleCyc Pow c1
        in CT MSD k l $ P.const c0 + rescaleLinearMSD (switch hint c1Up)
  return $! hint `seq` switchCT
-- | Produces a function that turns a quadratic ciphertext
-- (@c0 + c1*X + c2*X^2@) under key @s@ into a linear one under the same key.
--
-- A hint for @s*s@ under @s@ is generated once with 'ksHint'.  Each call
-- then:
--
--   1. puts the ciphertext in 'MSD' form;
--   2. rescales @c2@ to @zq'@;
--   3. key-switches @c2@ with the hint;
--   4. rescales the result back to @zq@ and adds it to @c0 + c1*X@.
--
-- As in 'keySwitchLinear', @$!@ forces the hint when the function is
-- produced rather than on its first use.  A ciphertext that does not have
-- exactly three coefficients in 'MSD' form is rejected with a descriptive
-- error.
keySwitchQuadCirc :: forall gad t m' zp zq zq' z m rnd .
  (KeySwitchCtx gad t m' zp zq zq', KSHintCtx gad t m' z zq', MonadRandom rnd)
  => SK (Cyc t m' z)
  -> TaggedT (gad, zq') rnd (CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq))
keySwitchQuadCirc sk@(SK _ s) = tagT $ do
  hint :: Tagged gad [Polynomial (Cyc t m' zq')] <- ksHint sk (s*s)
  return $! hint `seq` (\ct ->
    let CT MSD k l c = toMSD ct
    in case coeffs c of
         [c0, c1, c2] ->
           let c2' = rescaleCyc Pow c2
           in CT MSD k l $ P.fromCoeffs [c0, c1] + rescaleLinearMSD (switch hint c2')
         cs -> error $ "keySwitchQuadCirc: expected a quadratic ciphertext "
                    ++ "(3 coefficients), got " ++ show (length cs))
-- | Constraints needed by 'addScalar'.
type AddScalarCtx t m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, ToSDCtx t m' zp zq)

-- | Adds a public scalar to a ciphertext.
--
-- The ciphertext is put in 'LSD' form and must carry no factors of @g@.
-- The scalar is divided by the ciphertext's scale, lifted, reduced to
-- @zq@, and added to the constant coefficient.
addScalar :: (AddScalarCtx t m' zp zq)
          => zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
addScalar b ct = CT LSD 0 l $ c + P.const bq
  where
    (l, c) = case toLSD ct of
      CT LSD 0 scale poly -> (scale, poly)
      CT LSD _ _ _ -> error "cannot add public scalar to ciphertext with 'g' factors"
      _ -> error "internal error: addScalar"
    bq = scalarCyc (reduce $ lift $ b * recip l)
-- | Constraints needed by 'addPublic'.
type AddPublicCtx t m m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, m `Divides` m',
   CElt t zp, CElt t (LiftOf zp), ToSDCtx t m' zp zq)

-- | Adds a public plaintext-ring element to a ciphertext.
--
-- The ciphertext is put in 'LSD' form.  The element is then:
--
--   1. multiplied by @g@ @k@ times (matching the ciphertext's g factors);
--   2. multiplied by the inverse of the ciphertext's scale;
--   3. lifted via 'Pow' and reduced to @zq@;
--   4. embedded and added to the constant coefficient.
addPublic :: forall t m m' zp zq . (AddPublicCtx t m m' zp zq)
          => Cyc t m zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
addPublic b ct = CT LSD k l $ c + P.const (embed bq)
  where
    CT LSD k l c = toLSD ct
    bq :: Cyc t m zq
    bq = reduce $ liftCyc Pow $ scalarCyc (recip l) * (iterate mulG b !! k)
-- | Constraints needed by 'mulPublic'.
type MulPublicCtx t m m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, Ring zq, m `Divides` m',
   CElt t zp, CElt t (LiftOf zp), CElt t zq)

-- | Multiplies a ciphertext by a public plaintext-ring element.  The
-- element is lifted via 'Pow', reduced to @zq@, embedded, and multiplied
-- into every coefficient.  Encoding, @k@ and scale are unchanged.
mulPublic :: forall t m m' zp zq . (MulPublicCtx t m m' zp zq)
          => Cyc t m zp -> CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
mulPublic a (CT enc k l c) = CT enc k l (fmap (aq *) c)
  where
    aq = embed (reduce (liftCyc Pow a) :: Cyc t m zq)
-- | Multiplies every coefficient of a ciphertext by @g@ and records this by
-- incrementing its g-factor count.
mulGCT :: (Fact m', CElt t zq)
       => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
mulGCT (CT enc k l c) = CT enc (k + 1) l (fmap mulG c)
-- | Ciphertext addition.  Both operands must have the same scale.
--
--   * If their g-factor counts differ, the operand with fewer factors is
--     multiplied by @g@ ('mulGCT') until the counts match.
--   * If their encodings differ, the 'LSD' operand is converted to 'MSD'.
--   * The polynomials are then added.
instance (Eq zp, m `Divides` m', ToSDCtx t m' zp zq)
  => Additive.C (CT m zp (Cyc t m' zq)) where
  zero = CT LSD 0 one zero

  ct1@(CT enc1 k1 l1 c1) + ct2@(CT enc2 k2 l2 c2)
    | l1 /= l2 = error "Cannot add ciphertexts with different scale values"
    -- Equalize g-factor counts; the index difference is positive in
    -- each branch.
    | k1 < k2 = iterate mulGCT ct1 !! (k2 - k1) + ct2
    | k1 > k2 = ct1 + iterate mulGCT ct2 !! (k1 - k2)
    | enc1 == LSD && enc2 == MSD = toMSD ct1 + ct2
    | enc1 == MSD && enc2 == LSD = ct1 + toMSD ct2
    | otherwise = CT enc1 k1 l1 $ c1 + c2

  negate (CT enc k l c) = CT enc k l $ negate <$> c
-- | Ciphertext multiplication.
--
-- The operands are arranged so that the first is in 'LSD' form.  The
-- product polynomial is multiplied by @g@, so the result has:
--
--   * g-factor count @k1 + k2 + 1@;
--   * scale @l1 * l2@;
--   * the encoding of the second operand.
instance (ToSDCtx t m' zp zq, Additive (CT m zp (Cyc t m' zq)))
  => Ring.C (CT m zp (Cyc t m' zq)) where
  one = CT LSD 0 one one

  ct1@(CT enc1 k1 l1 c1) * ct2@(CT enc2 k2 l2 c2) =
    case (enc1, enc2) of
      -- Neither operand is LSD: convert the first one.
      (MSD, MSD) -> toLSD ct1 * ct2
      (LSD, _)   -> CT enc2 (k1 + k2 + 1) (l1 * l2) (mulG <$> (c1 * c2))
      -- Only the second operand is LSD: swap so it comes first.
      (MSD, LSD) -> ct2 * ct1
-- | Constraints needed by 'absorbGFactors'.
type AbsorbGCtx t m' zp zq =
  (Lift' zp, Reduce (LiftOf zp) zq, Ring zp, Ring zq, Fact m',
   CElt t (LiftOf zp), CElt t zp, CElt t zq)

-- | Absorbs a ciphertext's @k@ factors of @g@, producing a ciphertext with
-- @k = 0@.
--
-- 'one' is divided by @g@ @k@ times in @zp@ (via 'divG''), lifted via
-- 'Pow', reduced to @zq@, and multiplied into every coefficient.  A
-- ciphertext with @k = 0@ is returned unchanged; a negative @k@ is an
-- error.
absorbGFactors :: forall t zp zq m m' . (AbsorbGCtx t m' zp zq)
               => CT m zp (Cyc t m' zq) -> CT m zp (Cyc t m' zq)
absorbGFactors ct@(CT enc k l c)
  | k < 0     = error "k < 0 in absorbGFactors"
  | k == 0    = ct
  | otherwise = CT enc 0 l $ (factor *) <$> c
  where
    gInvPow :: Cyc t m' zp
    gInvPow = iterate divG' one !! k
    factor = adviseCRT $ reduce $ liftCyc Pow gInvPow
-- | Embeds a ciphertext into larger cyclotomic indices by embedding each
-- coefficient.  The plaintext index goes from @r@ to @s@ and the
-- ciphertext index from @r'@ to @s'@.
--
-- Only ciphertexts with no factors of @g@ are accepted; call
-- 'absorbGFactors' first otherwise.
embedCT :: (CElt t zq,
            r `Divides` r', s `Divides` s', r `Divides` s, r' `Divides` s')
        => CT r zp (Cyc t r' zq) -> CT s zp (Cyc t s' zq)
embedCT (CT d 0 l c) = CT d 0 l (embed <$> c)
embedCT _ = error "embedCT requires 0 factors of g; call absorbGFactors first"
-- | Embeds a secret key into a larger cyclotomic index, keeping its error
-- parameter.
embedSK :: (CElt t z, m `Divides` m') => SK (Cyc t m z) -> SK (Cyc t m' z)
embedSK (SK v s) =
  let s' = embed s
  in SK v s'
-- | Applies 'twace' to every coefficient of a ciphertext, moving the
-- ciphertext index from @r'@ down to @s'@ and the plaintext index from @r@
-- to @s = FGCD s' r@.
--
-- Only ciphertexts with no factors of @g@ are accepted.
twaceCT :: (CElt t zq, r `Divides` r', s' `Divides` r',
            s ~ (FGCD s' r))
        => CT r zp (Cyc t r' zq) -> CT s zp (Cyc t s' zq)
twaceCT ct = case ct of
  CT d 0 l c -> CT d 0 l (fmap twace c)
  _ -> error "twaceCT requires 0 factors of g; call absorbGFactors first"
-- | Constraints needed by 'tunnelCT'.
type TunnelCtx t e r s e' r' s' z zp zq gad =
  (ExtendLinIdx e r s e' r' s',
   e' ~ (e * (r' / r)),
   ToSDCtx t r' zp zq,
   KSHintCtx gad t r' z zq,
   Reduce z zq,
   Lift zp z,
   CElt t zp,
   SwitchCtx gad t s' zq)

-- | Tunnels ciphertexts through the 'Linear' function @f@.
--
-- The produced function maps ciphertexts with plaintext index @r@ and
-- ciphertext index @r'@, under key @sin@ (third argument), to ciphertexts
-- with plaintext index @s@ and ciphertext index @s'@, under key @skout@
-- (second argument).
--
-- Setup, run once in @rnd@:
--
--   * @f@ is lifted and extended ('extendLin') to @f'@ over @e' r' s'@,
--     and a copy @f'q@ is reduced to @zq@.
--   * For each element of 'powBasis' taken relative to @e'@, @f'@ is
--     evaluated at that element times the input key.
--   * 'ksHint' builds a hint for each of those values under @skout@.
--
-- The returned function is @hints `deepseq` (\\ct -> ...)@, so all hints
-- are deep-forced when that function value is first evaluated.
--
-- Per ciphertext:
--
--   1. Its g factors are absorbed ('absorbGFactors') and it is put in
--      'MSD' form; it must then be linear (@[c0,c1]@).
--   2. @c0@ is mapped through @f'q@.
--   3. @c1@ is split into its 'Pow' coefficients over @e'@; each is
--      embedded and key-switched with the matching hint.
--   4. The switched polynomials are summed and added to the mapped @c0@.
--
-- The scale is carried over unchanged.
tunnelCT :: forall gad t e r s e' r' s' z zp zq rnd .
  (TunnelCtx t e r s e' r' s' z zp zq gad,
   MonadRandom rnd)
  => Linear t zp e r s
  -> SK (Cyc t s' z)
  -> SK (Cyc t r' z)
  -> TaggedT gad rnd (CT r zp (Cyc t r' zq) -> CT s zp (Cyc t s' zq))
tunnelCT f skout (SK _ sin) = tagT $ (do
  -- Lift and extend f; keep a copy reduced to zq for the constant term.
  let f' = extendLin $ lift f :: Linear t z e' r' s'
      f'q = reduce f' :: Linear t zq e' r' s'
      ps = proxy powBasis (Proxy::Proxy e')
      -- One value per basis element: f' applied to (basis element * sin).
      comps = (evalLin f' . (adviseCRT sin *)) <$> ps
  -- One key-switching hint per value, generated under the output key.
  hints :: [Tagged gad [Polynomial (Cyc t s' zq)]] <- CM.mapM (ksHint skout) comps
  return $ hints `deepseq` \ct ->
    let CT MSD 0 s c = toMSD $ absorbGFactors ct
        [c0,c1] = coeffs c
        c0' = evalLin f'q c0
        -- Pow-basis coefficients of c1 over e', each switched with its hint.
        c1s = coeffsCyc Pow c1 :: [Cyc t e' zq]
        c1s' = zipWith switch hints (embed <$> c1s)
        c1' = sum c1s'
    in CT MSD 0 s $ P.const c0' + c1')
  -- NOTE(review): the trailing lcmDivides term presumably supplies a
  -- divisibility constraint relating r and e' that the body needs; confirm.
  \\ lcmDivides (Proxy::Proxy r) (Proxy::Proxy e')