module Crypto.Lol.Applications.KeyHomomorphicPRF
(FullBinTree(..), evalTree
,randomTree, balancedTree, leftSpineTree, rightSpineTree
,PRFFamily, makeFamily, randomFamily
,grayCode
,PRFState, prfState
,latticePRF, latticePRFM
,ringPRF, ringPRFM
) where
import Control.Applicative ((<$>))
import Control.DeepSeq
import Control.Monad.Random hiding (fromList)
import Control.Monad.State
import Crypto.Lol
import Data.Bits
import Data.Maybe (fromMaybe)
import MathObj.Matrix
-- | A full binary tree: every node is either a leaf or has exactly two
-- children.
data FullBinTree
  = L
    -- ^ A leaf.
  | I Int FullBinTree FullBinTree
    -- ^ An internal node: an 'Int' that the rest of this module reads as the
    -- number of leaves below the node, followed by the left and right
    -- subtrees.
-- | Fully evaluates a tree: the stored 'Int' and both subtrees of every
-- internal node.
instance NFData FullBinTree where
  rnf tree = case tree of
    L           -> ()
    I n lhs rhs -> rnf n `seq` (rnf lhs `seq` rnf rhs)
-- | Parameters of a PRF family: two matrices over @rq@ and a tree shape.
-- The @gad@ and @rp@ type parameters are phantom; no field mentions them.
data PRFFamily gad rq rp
  = Params
      (Matrix rq)   -- first matrix (bound as @a0@ wherever 'Params' is matched)
      (Matrix rq)   -- second matrix (bound as @a1@ wherever 'Params' is matched)
      FullBinTree   -- shape of the evaluation tree
-- | Fully evaluates both matrices and the tree of a family.
instance (NFData rq) => NFData (PRFFamily gad rq rp) where
  rnf (Params a0 a1 tree) = rnf (a0, a1, tree)
-- | Smart constructor for 'PRFFamily' that checks the matrix dimensions
-- before wrapping them in 'Params'.
--
-- Calls 'error' when @a0@ and @a1@ have different numbers of rows, or when
-- either matrix does not have exactly (rows * gadget length) columns, where
-- the gadget length is the length of the @gad@ gadget over @rq@.  The checks
-- run as soon as both matrices are supplied; the tree is stored as given.
makeFamily :: forall rq rp gad . (Gadget gad rq)
  => Matrix rq -> Matrix rq -> FullBinTree -> PRFFamily gad rq rp
makeFamily a0 a1
  | rows0 /= rows1 =
      error $ "a0 has " ++ show rows0 ++
              " rows, but a1 has " ++ show rows1 ++ " rows."
  | cols0 /= want0 =
      error $ "Expected " ++ show want0 ++
              " columns in a0, but there are " ++ show cols0 ++ "."
  | cols1 /= want1 =
      error $ "Expected " ++ show want1 ++
              " columns in a1, but there are " ++ show cols1 ++ "."
  | otherwise = Params a0 a1
  where
    -- Number of entries in the gadget vector for this @gad@ / @rq@ pair.
    gadLen = length $ untag (gadget :: Tagged gad [rq])
    rows0  = numRows a0
    rows1  = numRows a1
    cols0  = numColumns a0
    cols1  = numColumns a1
    want0  = rows0 * gadLen
    want1  = rows1 * gadLen
-- | A tree whose nodes carry matrices computed for one particular input.
data DecoratedTree r
  = DL
      Int          -- the input bit assigned to this leaf
      (Matrix r)   -- the matrix selected for that bit
  | DI
      Int                 -- number of leaves below this node
      Int                 -- the input bits this node was evaluated on
      (Matrix r)          -- value of this node
      (DecoratedTree r)   -- left subtree
      (Matrix r)          -- decomposed value of the right subtree
      (DecoratedTree r)   -- right subtree
-- | Fully evaluates every field of every node of a decorated tree.
instance (NFData r) => NFData (DecoratedTree r) where
  rnf node = case node of
    DL bit mat ->
      rnf (bit, mat)
    DI leaves input val lhs dec rhs ->
      rnf (leaves, input, val, lhs, dec, rhs)
-- | Evaluation state of a PRF: a proxy for the (existentially hidden) gadget
-- type, the two family matrices, and the decorated tree of the most recent
-- evaluation.  The @rp@ parameter is phantom.
data PRFState rq rp where
  PRFState
    :: (Decompose gad rq)
    => Proxy gad
    -> Matrix rq
    -> Matrix rq
    -> DecoratedTree rq
    -> PRFState rq rp
-- | Fully evaluates both matrices and the decorated tree.  The proxy is
-- matched against its constructor, which forces it as well.
instance (NFData rq) => NFData (PRFState rq rp) where
  rnf (PRFState Proxy a0 a1 tree) = rnf (a0, a1, tree)
-- | Creates the initial 'PRFState' for a family.
--
-- The optional second argument is the input the state is initialised on;
-- 'Nothing' means input 0.  The input must satisfy @0 <= input < 2^leaves@,
-- where @leaves@ is the leaf count recorded at the root of the family's tree
-- (1 for a bare leaf); otherwise 'error' is called.
prfState :: forall gad rq rp . (Decompose gad rq)
  => PRFFamily gad rq rp -> Maybe Int -> PRFState rq rp
prfState family@(Params a0 a1 tree) initInput
  | inRange   = PRFState gadProxy a0 a1 (buildDecTree gadProxy input family)
  | otherwise =
      error $ "prfState: Input tree has " ++ show numLeaves ++
              " leaves, but input " ++ show input ++ " has " ++
              show (logBase 2 (fromIntegral input) :: Double) ++ " bits."
  where
    gadProxy  = Proxy :: Proxy gad
    input     = fromMaybe 0 initInput
    numLeaves = case tree of
      L       -> 1
      I n _ _ -> n
    inRange   = 0 <= input && input < 2^numLeaves
-- | Builds the decorated tree of a family for the input @y@.
--
-- Each leaf consumes one input bit and is decorated with @a0@ (bit 0) or
-- @a1@ (bit 1).  At an internal node the input is split: the low
-- @numRightLeaves@ bits go to the right subtree and the remaining high bits
-- to the left subtree.  The node's value is the left value multiplied by the
-- reduced gadget decomposition of the right value; both the value and the
-- decomposition are stored in the node.
--
-- Calls 'error' if a leaf is reached with an input other than 0 or 1, which
-- happens when @y@ has more bits than the tree has leaves.
buildDecTree :: (Decompose gad rq)
  => Proxy gad -> Int -> PRFFamily gad rq rp -> DecoratedTree rq
buildDecTree pgad y (Params a0 a1 t) =
  let getNumLeaves L = 1
      getNumLeaves (I x _ _) = x
      go 0 L = (a0, DL 0 a0)
      go 1 L = (a1, DL 1 a1)
      go x L =
        error $ "buildDecTree: a leaf input must be 0 or 1, but got " ++
                show x ++ "."
      go x (I numLeaves ltree rtree) =
        let numRightLeaves = getNumLeaves rtree
            -- mask selecting the low numRightLeaves bits
            rbits = x .&. ((2^numRightLeaves) - 1)
            -- a negative shift amount is a right shift
            lbits = shift x (-numRightLeaves)
            (lval, ltree') = go lbits ltree
            (rval, rtree') = go rbits rtree
            decompr = fmap reduce $ proxy (decomposeMatrix rval) pgad
            val = lval * decompr
        in (val, DI numLeaves x val ltree' decompr rtree')
  in snd $ go y t
-- | Evaluates the tree on input @y@, reusing cached work from the given
-- state, and returns the root value together with the updated state.
--
-- A subtree whose cached input equals the requested one is returned as is.
-- Otherwise the requested input is split exactly as when the tree was built
-- (low @numRightLeaves@ bits to the right subtree, remaining high bits to
-- the left) and the node is recomputed.  The cached decomposition of the
-- right value is reused only when the right subtree reports that its value
-- did not change.
--
-- Calls 'error' if a leaf is reached with an input other than 0 or 1, which
-- happens when @y@ has more bits than the tree has leaves.
evalTree :: Int -> PRFState rq rp -> (Matrix rq, PRFState rq rp)
evalTree y (PRFState pgad a0 a1 t) =
  let getNumLeaves (DL _ _) = 1
      getNumLeaves (DI n _ _ _ _ _) = n
      -- The third component says whether the node's value differs from the
      -- cached one.  A leaf changes exactly when its bit flips; the parent
      -- needs this to know that its cached decomposition is stale.
      go 0 (DL old _) = (a0, DL 0 a0, old /= 0)
      go 1 (DL old _) = (a1, DL 1 a1, old /= 1)
      go i (DL _ _) =
        error $ "evalTree: a leaf input must be 0 or 1, but got " ++
                show i ++ "."
      go i n@(DI numLeaves x val ltree decompr rtree)
        | i == x = (val, n, False)
        | otherwise =
          let numRightLeaves = getNumLeaves rtree
              -- Split the NEW input i (not the cached input x).
              rbits = i .&. ((2^numRightLeaves) - 1)
              -- a negative shift amount is a right shift
              lbits = shift i (-numRightLeaves)
              (lval, ltree', _) = go lbits ltree
              (rval, rtree', changed) = go rbits rtree
              decompr' = if changed
                         then fmap reduce $ proxy (decomposeMatrix rval) pgad
                         else decompr
              val' = lval * decompr'
          in (val', DI numLeaves i val' ltree' decompr' rtree', True)
      (res, t', _) = go y t
  in (res, PRFState pgad a0 a1 t')
-- | Evaluates the tree on input @x@, multiplies the secret row vector @s@ by
-- the result and applies 'rescale' to every entry.  Returns the rescaled
-- matrix together with the updated state.
--
-- Calls 'error' unless @s@ has exactly one row and as many columns as @a0@
-- has rows.
latticePRF' :: (Rescale zq zp)
  => Matrix zq -> Int -> PRFState zq zp -> (Matrix zp, PRFState zq zp)
latticePRF' s x state1@(PRFState _ a0 _ _)
  | keyRows /= 1 = error "Secret key must have one row."
  | keyCols /= numRows a0 =
      error $ "Secret key has " ++ show keyCols ++
              " columns, but a0 has " ++ show (numRows a0) ++ " rows."
  | otherwise = (fmap rescale (s * treeVal), state2)
  where
    keyRows = numRows s
    keyCols = numColumns s
    (treeVal, state2) = evalTree x state1
-- | Like @latticePRF'@, but returns only the PRF output and discards the
-- updated state.
latticePRF :: (Rescale zq zp)
  => Matrix zq -> Int -> PRFState zq zp -> Matrix zp
latticePRF s x st = fst (latticePRF' s x st)
-- | Like @latticePRF'@, but reads and writes the 'PRFState' through the
-- surrounding 'MonadState' instead of passing it explicitly.
latticePRFM :: (MonadState (PRFState zq zp) mon, Rescale zq zp)
  => Matrix zq -> Int -> mon (Matrix zp)
latticePRFM s x = state (\st -> latticePRF' s x st)
-- | Evaluates the tree on input @x@, multiplies every entry of the result by
-- the secret @s@ and applies 'rescaleDec' to it.  Returns the resulting
-- matrix together with the updated state.
ringPRF' :: (Fact m, RescaleCyc (Cyc t) zq zp, Ring rq,
             rq ~ Cyc t m zq, rp ~ Cyc t m zp)
  => rq -> Int -> PRFState rq rp -> (Matrix rp, PRFState rq rp)
ringPRF' s x state1 = (fmap (\entry -> rescaleDec (s * entry)) treeVal, state2)
  where
    (treeVal, state2) = evalTree x state1
-- | Like @ringPRF'@, but returns only the PRF output and discards the
-- updated state.
ringPRF :: (Fact m, RescaleCyc (Cyc t) zq zp, Ring rq,
            rq ~ Cyc t m zq, rp ~ Cyc t m zp)
  => rq -> Int -> PRFState rq rp -> Matrix rp
ringPRF s x st = fst (ringPRF' s x st)
-- | Like @ringPRF'@, but reads and writes the 'PRFState' through the
-- surrounding 'MonadState' instead of passing it explicitly.
ringPRFM :: (MonadState (PRFState rq rp) mon, Fact m,
             RescaleCyc (Cyc t) zq zp, Ring rq,
             rq ~ Cyc t m zq, rp ~ Cyc t m zp)
  => rq -> Int -> mon (Matrix rp)
ringPRFM s x = state (\st -> ringPRF' s x st)
-- | Generates a random full binary tree with @i@ leaves.
--
-- The size of the left subtree is drawn with 'getRandomR' from @1 .. i-1@;
-- the right subtree receives the remaining leaves.  Both subtrees are then
-- generated recursively.  Calls 'error' if @i@ is less than 1.
randomTree :: (MonadRandom rnd) => Int -> rnd FullBinTree
randomTree i
  | i < 1 =
      error $ "randomTree: number of leaves must be positive, but got " ++
              show i ++ "."
  | i == 1 = return L
  | otherwise = do
      leftSize <- getRandomR (1, i - 1)
      left <- randomTree leftSize
      right <- randomTree (i - leftSize)
      return $ I i left right
-- | Builds the tree with @i@ leaves in which the right child of every
-- internal node is a leaf.  Calls 'error' if @i@ is less than 1.
leftSpineTree :: Int -> FullBinTree
leftSpineTree i
  | i < 1 =
      error $ "leftSpineTree: number of leaves must be positive, but got " ++
              show i ++ "."
  | i == 1 = L
  | otherwise = I i (leftSpineTree (i - 1)) L
-- | Builds the tree with @i@ leaves in which the left child of every
-- internal node is a leaf.  Calls 'error' if @i@ is less than 1.
rightSpineTree :: Int -> FullBinTree
rightSpineTree i
  | i < 1 =
      error $ "rightSpineTree: number of leaves must be positive, but got " ++
              show i ++ "."
  | i == 1 = L
  | otherwise = I i L (rightSpineTree (i - 1))
-- | Builds a tree with @i@ leaves by recursive splitting.
--
-- With @p@ the largest power of two not exceeding @i@, the left subtree
-- receives @min p (i - p `div` 2)@ leaves and the right subtree the rest.
-- Calls 'error' if @i@ is less than 1.
balancedTree :: Int -> FullBinTree
balancedTree i
  | i < 1 =
      error $ "balancedTree: number of leaves must be positive, but got " ++
              show i ++ "."
  | i == 1 = L
  | otherwise = I i (balancedTree lsize) (balancedTree rsize)
  where
    -- Largest power of two <= i, found with integer arithmetic so the result
    -- does not depend on floating-point rounding.  Comparing against
    -- i `div` 2 (rather than doubling p first) cannot overflow.
    lastFullLevelSize = largestPow2 1
    largestPow2 p
      | p > i `div` 2 = p
      | otherwise     = largestPow2 (2 * p)
    lsize = min lastFullLevelSize (i - (lastFullLevelSize `div` 2))
    rsize = i - lsize
-- | Samples a random PRF family for a tree with @size@ leaves.
--
-- The tree comes from 'randomTree'.  Each of the two matrices is a single
-- row whose length is the gadget length for @gad@ over @rq@, filled with
-- values drawn from 'getRandoms'.  The result is passed through
-- 'makeFamily', so its dimension checks apply.
randomFamily :: forall gad rnd rq rp . (MonadRandom rnd, Random rq, Gadget gad rq)
  => Int -> rnd (PRFFamily gad rq rp)
randomFamily size = do
  tree <- randomTree size
  let gadLen    = length (untag (gadget :: Tagged gad [rq]))
      sampleRow = fromList 1 gadLen . take gadLen <$> getRandoms
  row0 <- sampleRow
  row1 <- sampleRow
  return (makeFamily row0 row1 tree)
-- | Returns the reflected Gray code on @n@ bits: a list of all values in
-- @0 .. 2^n - 1@ in which consecutive entries differ in exactly one bit.
--
-- The code for @n@ bits is the code for @n-1@ bits followed by its reverse
-- with the top bit set.  @grayCode 0@ is @[0]@; a negative @n@ calls 'error'.
grayCode :: Int -> [Int]
grayCode n
  | n < 0 =
      error $ "grayCode: number of bits must be non-negative, but got " ++
              show n ++ "."
  | n == 0 = [0]
  | otherwise =
      let gc' = grayCode (n - 1)
          pow2 = 2^(n - 1)
          rightHalf = map (+pow2) $ reverse gc'
      in gc' ++ rightHalf