{- |
Module      : Data.NTRU
Description : NTRU cryptographic system implementation
Maintainer  : tlevine@cyberpointllc.com
Stability   : Experimental
License     : MIT

This is an implementation of the NTRU cryptographic system, following the
standard set forth by the IEEE in the document entitled /IEEE Standard
Specification for Public Key Cryptographic Techniques Based on Hard Problems
over Lattices/. It is designed to be compatible with the implementation by
Security Innovation.
-}
module Data.NTRU
  ( keyGen112
  , keyGen128
  , keyGen192
  , keyGen256
  , encrypt112
  , encrypt128
  , encrypt192
  , encrypt256
  , decrypt112
  , decrypt128
  , decrypt192
  , decrypt256
  ) where

import Data.Digest.Pure.SHA
import Data.List.Split
import Data.Sequence as Seq (index, update, empty, fromList, Seq)
import Data.Foldable as L (toList)
import Crypto.Random
import System.Random
import Math.Polynomial
import Math.NumberTheory.Moduli
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL

{- Polynomial Operations -}

-- | Converts a polynomial to its list of coefficients, lowest degree first
-- (little-endian order, 'LE').
fromPoly :: (Num a, Eq a, Integral a) => Poly a -> [a]
fromPoly = polyCoeffs LE

-- | Builds a polynomial from a list of coefficients, lowest degree first
-- (little-endian order, 'LE').
toPoly :: (Num a, Eq a, Integral a) => [a] -> Poly a
toPoly = poly LE

-- | Retrieves the coefficient of the @x^i@ term of @p@.
--
-- The function is total: a negative index, or an index past the end of the
-- coefficient list returned by 'fromPoly', denotes a term that is not present,
-- so its coefficient is 0 (instead of crashing on @!!@).
polyCoef :: (Num a, Eq a, Integral a) => Poly a -> Int -> a
polyCoef p i
  | i < 0 = 0
  | otherwise = case drop i (fromPoly p) of
      (c : _) -> c
      [] -> 0
-- | Lets polynomials be combined with the ordinary numeric operators
-- (@f + g@, @f - g@, @f * g@, @negate f@).
--
-- Note that '*' is plain polynomial multiplication: for multiplication in the
-- NTRU ring, 'reduceDegree' must be applied to the product. 'fromInteger'
-- yields a constant polynomial, so numeric literals can be used as
-- polynomials. 'abs' and 'signum' have no sensible meaning for polynomials and
-- raise an error.
instance (Num a, Eq a) => Num (Poly a) where
  f + g = addPoly f g
  f * g = multPoly f g
  negate = negatePoly
  abs = error "Data.NTRU: abs is not defined for polynomials"
  signum = error "Data.NTRU: signum is not defined for polynomials"
  fromInteger = constPoly . fromInteger

-- | Reduces a polynomial modulo @x^n - 1@: the coefficient of @x^(k + j*n)@ is
-- added onto the coefficient of @x^k@ for every @j@. Consequently
-- @reduceDegree n (a * b)@ is the product of @a@ and @b@ in the ring of size
-- @n@, whatever the degree of the product. For @n <= 0@ the polynomial is
-- returned unchanged.
reduceDegree :: (Num a, Eq a, Integral a) => Int -> Poly a -> Poly a
reduceDegree n f
  | n <= 0 = f
  | otherwise = foldr (addPoly . toPoly) zero (chunksOf n (fromPoly f))

-- | Reduces every coefficient of the polynomial mod @q@ (into @[0, q-1]@ for a
-- positive @q@).
polyMod :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a
polyMod q = toPoly . map (`mod` q) . fromPoly

-- | Like 'polyMod', but uses centred representatives: a reduced coefficient is
-- kept if it is at most @q `div` 2@ and has @q@ subtracted otherwise (so for
-- @q = 3@ every coefficient becomes -1, 0 or 1).
polyModInterval :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a
polyModInterval q = toPoly . map (centre . (`mod` q)) . fromPoly
  where
    centre c
      | c <= q `div` 2 = c
      | otherwise = c - q

-- | 'polyMod' for a modulus given as an 'Int'. The reduction happens in the
-- coefficient type itself, so large coefficients are never truncated by a
-- round trip through 'Int'.
polyBigMod :: (Num a, Eq a, Integral a) => Int -> Poly a -> Poly a
polyBigMod q = polyMod (fromIntegral q)

-- | The monomial @x^n@.
xPow :: (Num a, Eq a, Integral a) => Int -> Poly a
xPow = powPoly x

{- Key Generation -}

-- | 6.3.3.1 Division with remainder for polynomials with coefficients mod @p@:
-- for @(quo, r) = divPolyMod p a b@, @(b * quo + r) `mod` p == a `mod` p@ and
-- the degree of @r@ is below the degree of @b@. The leading coefficient of @b@
-- must be invertible mod @p@ ('inverseMod' raises an error otherwise).
divPolyMod :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a -> (Poly a, Poly a)
divPolyMod p a b = divLoop zero a
  where
    n = polyDegree b
    -- inverse of b's leading coefficient, used to cancel leading terms of r
    u = inverseMod (polyCoef b n) p
    divLoop quo r
      | d < n = (polyMod p quo, polyMod p r)
      | otherwise = divLoop (polyMod p (quo + v)) (polyMod p (r - v * b))
      where
        d = polyDegree r
        -- the term that cancels the leading coefficient of r
        v = scalePoly (u * polyCoef r d) (xPow (d - n))

-- | 6.3.3.2 Extended Euclidean algorithm with coefficients mod @p@: for
-- @(d, u) = extendedEuclidean p a b@, @d@ is the last non-zero remainder and
-- @u * a@ is congruent to @d@ modulo @p@ and modulo @b@. In particular, if
-- @d == one@ then @u@ is the inverse of @a@ in that ring.
extendedEuclidean :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a -> (Poly a, Poly a)
extendedEuclidean p a b = go one a zero b
  where
    go u d v1 v3
      | polyIsZero v3 = (d, u)
      | otherwise =
          let (quo, t3) = divPolyMod p d v3
              t1 = polyMod p (u - quo * v1)
          in go v1 v3 t1 t3

-- | Draws candidates @f = 1 + p*F@ (with @F@ holding @df@ ones and @df@ minus
-- ones) until one is invertible mod 2 in the ring of size @N@. For
-- @(f, u) = findInversable params@, @u@ is the inverse of @f@ mod 2 and
-- mod @x^N - 1@.
findInversable :: (Num a, Eq a, Integral a) => [Int] -> IO (Poly a, Poly a)
findInversable params = do
  let n = getN params
      df = getDf params
  bigF <- genRandPoly n df df
  let f = scalePoly (getP params) bigF + one
      (d, u) = extendedEuclidean 2 f (xPow n - one)
  if d == one then return (f, u) else findInversable params

-- | 6.3.3.4 Lifts an inverse mod 2 to an inverse mod @q@ by Newton iteration.
-- Each step maps @b@ to @2b - a*b^2@ and squares the modulus the inverse is
-- valid for (2, 4, 16, 256, 65536, ...). Once that modulus reaches @q@, the
-- result is reduced mod @q@. @q@ is expected to be a power of two. With
-- @(a, u) = findInversable params@,
-- @reduceDegree n (a * inverseLift a u n q) `mod` q == 1@.
inverseLift :: (Num a, Eq a, Integral a) => Poly a -> Poly a -> Int -> a -> Poly a
inverseLift a b deg q = go b 2
  where
    go cur m
      | m >= q = polyMod q cur
      | otherwise =
          let m' = m * m
              square = reduceDegree deg (cur * cur)
          in go (polyMod m' (scalePoly 2 cur - reduceDegree deg (a * square))) m'

-- | 9.2.1 Generates a key pair @(publicKey, privateKey)@ as coefficient lists.
-- The private key is @f = 1 + pF@, per enhancement 2 at
-- https://www.securityinnovation.com/uploads/Crypto/NTRU%20Enhancements%201.pdf
-- The public key is @p * fq * g@ reduced into the ring mod @q@, where @fq@ is
-- the inverse of @f@ mod @q@.
generateKeyPair :: (Num a, Eq a, Integral a) => [Int] -> IO ([a], [a])
generateKeyPair params = do
  let n = getN params
      dg = getDg params
      q = getQ params
      p = getP params
  (f, u) <- findInversable params
  g <- genRandPoly n dg (dg - 1)
  let fq = inverseLift f u n q
      h = polyMod q $ reduceDegree n $ scalePoly p (fq * g)
  return (fromPoly h, fromPoly f)

{- Blinding Polynomial Generation -}

-- | Builds the seed for 'bpgm': the OID octets, the message, the random
-- octets @b@, and the leading bits of the public key @h@. Each coefficient of
-- @h@ is taken as 11 bits, and the first @pkLen@ bits, rounded down to a
-- multiple of 8, are used as octets.
genSData :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [Int] -> [a]
genSData h msg b params =
  let pkLen = getPkLen params
      hBits = take (pkLen - pkLen `mod` 8) (concatMap bigIntToBits h)
      hTrunc = map bitsToInt (chunksOf 8 hBits)
  in map fromIntegral (getOID params) ++ msg ++ b ++ hTrunc

-- | 8.3.2.2 Generates the blinding polynomial (as @N@ coefficients) from the
-- given seed. The first index from the IGF gets a 1, then @dr - 1@ further
-- free positions get 1 and @dr@ free positions get -1. The IGF state is
-- threaded through all of the draws.
bpgm :: (Num a, Eq a, Integral a) => [a] -> [Int] -> [a]
bpgm seed params =
  let (i0, s0) = igf ([], [], 0) seed params
      r0 = Seq.update i0 1 $ Seq.fromList $ replicate (getN params) 0
      t = getDr params
      (r1, s1) = rlooper s0 1 r0 (t - 1) params
      (r2, _) = rlooper s1 (-1) r1 t params
  in L.toList r2

-- | Draws IGF indices and writes @val@ into the first @t@ that are still 0.
-- Occupied indices are skipped. Returns the updated coefficients and the IGF
-- state after the last draw.
rlooper :: (Num a, Eq a, Integral a) => ([a], [a], a) -> a -> Seq.Seq a -> Int -> [Int] -> (Seq.Seq a, ([a], [a], a))
rlooper s _ r t _
  | t <= 0 = (r, s)
rlooper s val r t params =
  let (i, s') = igf s [] params
  in if Seq.index r i == 0
       then rlooper s' val (Seq.update i val r) (t - 1) params
       else rlooper s' val r t params

-- | 8.4.2.1 Generates the next index in @[0, N)@. A non-empty @seed@
-- (re)initialises the state; otherwise the given state is continued.
-- Returns the index and the new state @(z, buffer, counter)@.
igf :: (Num a, Eq a, Integral a) => ([a], [a], a) -> [a] -> [Int] -> (Int, ([a], [a], a))
igf state seed params =
  let (z, buf, counter) = extractVariables state seed params
      (i, buf', counter') = genIndex counter buf z params
  in (i `mod` getN params, (z, buf', counter'))

-- | Either initializes the state from the seed, or, when the seed is empty,
-- uses the state already created.
extractVariables :: (Num a, Eq a, Integral a) => ([a], [a], a) -> [a] -> [Int] -> ([a], [a], a)
extractVariables state [] _ = state
extractVariables _ seed params = igfinit seed params

-- | Initializes the IGF state: @z = hash(seed)@, a bit buffer filled with
-- @minCallsR@ hash outputs, and the counter set to @minCallsR@.
igfinit :: (Num a, Eq a, Integral a) => [a] -> [Int] -> ([a], [a], a)
igfinit seed params =
  let minCallsR = getMinCallsR params
      shaFn = getSHA params
      z = shaFn seed
  in (z, buildM 0 minCallsR z shaFn [], minCallsR)

-- | Extracts @c@ bits from the bit buffer, topping the buffer up with further
-- hash outputs when fewer than @c@ bits remain. Values at or above the largest
-- multiple of @N@ below @2^c@ are rejected and redrawn. Returns the index, the
-- remaining buffer and the new counter.
genIndex :: (Num a, Eq a, Integral a) => a -> [a] -> [a] -> [Int] -> (Int, [a], a)
genIndex counter buf z params
  | i >= limit = genIndex counter' buf' z params
  | otherwise = (i, buf', counter')
  where
    c = getC params
    n = getN params
    shaFn = getSHA params
    hashBits = 8 * getHLen params
    remLen = length buf
    missing = c - remLen
    cThreshold = counter + fromIntegral ((missing + hashBits - 1) `div` hashBits)
    -- buildM returns the existing buffer followed by the new hash bits, so the
    -- stream must not be prefixed with buf a second time.
    (stream, counter')
      | remLen >= c = (buf, counter)
      | otherwise = (buildM counter cThreshold z shaFn buf, cThreshold)
    (b, buf') = splitAt c stream
    i = fromIntegral (bitsToInt b) :: Int
    limit = 2 ^ c - (2 ^ c `mod` n)

-- | Appends the bits of @hash(z || I2OSP(count, 4))@ to the buffer for every
-- counter value from @count@ up to (excluding) @cThreshold@.
buildM :: (Num a, Eq a, Integral a) => a -> a -> [a] -> ([a] -> [a]) -> [a] -> [a]
buildM count cThreshold z shaFn buf
  | count >= cThreshold = buf
  | otherwise = buildM (count + 1) cThreshold z shaFn (buf ++ intsToBits digest)
  where
    digest = shaFn (z ++ i2osp count 3)

-- | I2OSP: encodes the non-negative integer @i@ as @n + 1@ big-endian octets.
-- For example, @i2osp counter 3@ yields the 4-octet counter hashed by the IGF
-- and the MGF. Counters of 256 or more are split across octets correctly.
i2osp :: (Num a, Eq a, Integral a) => a -> a -> [a]
i2osp i n = [(i `div` (256 ^ k)) `mod` 256 | k <- [n, n - 1 .. 0]]

{- SHA Functionality -}

-- | Converts the lazy ByteString produced by the hash into a strict one so it
-- can be unpacked.
bToStrict :: BL.ByteString -> B.ByteString
bToStrict = B.concat . BL.toChunks

-- | SHA-1 of a list of octets; the output is 20 octets.
sha1Octets :: (Num a, Eq a, Integral a) => [a] -> [a]
sha1Octets input = map fromIntegral $ B.unpack $ bToStrict $ bytestringDigest $ sha1 $ BL.pack $ map fromIntegral input

-- | SHA-256 of a list of octets; the output is 32 octets.
sha256Octets :: (Num a, Eq a, Integral a) => [a] -> [a]
sha256Octets input = map fromIntegral $ B.unpack $ bToStrict $ bytestringDigest $ sha256 $ BL.pack $ map fromIntegral input

{- Mask Generation -}
-- Much of this code is similar to blinding polynomial generation, but it is
-- implemented separately.

-- | 8.4.1.1 Generates the mask, a list of @N@ trits in @{0,1,2}@, from the
-- given seed. Hash outputs @hash(z || I2OSP(counter, 4))@ for counter = 0, 1,
-- ... are turned into trits by 'formatI'. The size of the first batch of hash
-- calls only affects how much is hashed up front, not the resulting trits.
mgf :: (Num a, Eq a, Integral a) => [a] -> [Int] -> [a]
mgf seed params =
  let n = getN params
      shaFn = getSHA params
      z = shaFn seed
      calls = getMinCallsR params
      trits = formatI (buildBuffer 0 calls z shaFn [])
  in take n (finishI trits n calls z shaFn)

-- | Appends the octets of @hash(z || I2OSP(counter, 4))@ to the buffer for
-- every counter value from @counter@ up to (excluding) @minCallsR@.
buildBuffer :: (Num a, Eq a, Integral a) => a -> a -> [a] -> ([a] -> [a]) -> [a] -> [a]
buildBuffer counter minCallsR z shaFn buffer
  | counter >= minCallsR = buffer
  | otherwise = buildBuffer (counter + 1) minCallsR z shaFn (buffer ++ shaFn (z ++ i2osp counter 3))

-- | Step I: writes the non-negative value @o@ as @n@ base-3 digits, least
-- significant first.
toTrits :: (Num a, Eq a, Integral a) => a -> a -> [a]
toTrits n o
  | n == 0 = []
  | otherwise = o `mod` 3 : toTrits (n - 1) (o `div` 3)

-- | Appends the trits of further hash calls (counter, counter + 1, ...) to
-- @i@ until at least @n@ trits are available.
finishI :: (Num a, Eq a, Integral a) => [a] -> Int -> a -> [a] -> ([a] -> [a]) -> [a]
finishI i n counter z shaFn
  | length i >= n = i
  | otherwise =
      let more = formatI (buildBuffer counter (counter + 1) z shaFn [])
      in finishI (i ++ more) n (counter + 1) z shaFn

-- | Converts octets to trits: octets of 243 or more are dropped, and each
-- remaining octet yields 5 trits.
formatI :: (Num a, Eq a, Integral a) => [a] -> [a]
formatI buf = concatMap (toTrits 5) $ filter (< 243) buf

{- Encrypt -}

-- | 9.2.2 Encrypts @msg@ (a list of octets) under the public key @h@ using the
-- given parameter set. Raises an error if the message exceeds
-- @maxMsgLenBytes@.
--
-- R mod 4 is packed from all @N@ coefficients. Trailing zero coefficients
-- are kept, so the mask seed does not depend on how 'fromPoly' trims them.
--
-- NOTE(review): 'getDm0' is defined but never consulted, so no minimum-weight
-- check is applied to the masked message here or in 'decrypt' — confirm
-- whether that is intended.
encrypt :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a] -> IO [a]
encrypt params msg h
  | l > maxLength = error "message too long"
  | otherwise = do
      b <- randByteString bLen
      let p0 = replicate (maxLength - l) 0
          m = b ++ [fromIntegral l] ++ msg ++ p0
          mBin = addPadding (intsToBits m)
          mTrin = concatMap binToTern (chunksOf 3 mBin)
          r = bpgm (genSData h msg b params) params
          bigR = polyMod q $ reduceDegree n $ toPoly r * toPoly h
          or4 = toOctets $ improperPolynomial n $ fromPoly (polyMod 4 bigR)
          mask = mgf or4 params
          m' = polyModInterval p $ toPoly mask + toPoly mTrin
      return $ fromPoly (polyMod q (bigR + m'))
  where
    l = length msg
    maxLength = getMaxMsgLenBytes params
    bLen = getDb params `div` 8
    n = getN params
    q = getQ params
    p = getP params

{- Decrypt -}

-- | 9.3.3 Decrypts @e@ using the private key @f@ and verifies the result by
-- re-deriving the blinding polynomial with the public key @h@. Raises an
-- error (via 'checkValid') when the padding is not all zeros or the blinding
-- polynomial does not match.
decrypt :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a] -> [a] -> [a]
decrypt params f h e =
  let ci = polyMod p $ polyModInterval q $ reduceDegree n $ toPoly f * toPoly e
      cR = polyMod q $ toPoly e - polyModInterval p ci
      coR4 = toOctets $ improperPolynomial n $ fromPoly (polyMod 4 cR)
      cMask = polyMod p $ toPoly $ mgf coR4 params
      cMTrin = improperPolynomial n $ fromPoly $ polyModInterval p (ci - cMask)
      cMBin = concatMap ternToBin $ chunksOf 2 $ takeMultiple 2 cMTrin
      cM = map bitsToInt $ chunksOf 8 $ takeMultiple 8 cMBin
      (cb, rest) = splitAt bLen cM
      (lenOctets, rest') = splitAt (getLLen params) rest
      cl = octetsToInt lenOctets
      (cm, rest'') = splitAt (fromIntegral cl) rest'
      cr = bpgm (genSData h cm cb params) params
      cR' = polyMod q $ reduceDegree n $ toPoly cr * toPoly h
  in checkValid cm (cR' == cR) (all (== 0) rest'')
  where
    n = getN params
    p = getP params
    q = getQ params
    bLen = getDb params `div` 8
    -- longest prefix whose length is a multiple of k
    takeMultiple k xs = take (length xs - length xs `mod` k) xs
    -- big-endian octets to a number (the length field is lLen octets)
    octetsToInt os = sum (zipWith (*) (reverse os) (iterate (* 256) 1))

-- | Returns the message if both verification steps passed; otherwise raises
-- an error naming the failed check. The remainder check is reported first.
checkValid :: (Num a, Eq a, Integral a) => [a] -> Bool -> Bool -> [a]
checkValid _ _ False = error "Failure Checking Remainder of Message"
checkValid _ False _ = error "Failure Verifying Blinding Polynomial"
checkValid m _ _ = m

{- Other Operations -}

-- | The modular inverse of @x@ mod @y@: @(inverseMod x y * x) `mod` y == 1@.
-- Raises an error when no inverse exists.
inverseMod :: (Num a, Eq a, Integral a) => a -> a -> a
inverseMod x y = case invertMod (fromIntegral x) (fromIntegral y) of
  Just n -> fromIntegral n
  Nothing -> error "Could not calculate inverseMod"

-- | Generates @size@ random octets from the system's cryptographic generator.
randByteString :: (Num a, Eq a, Integral a) => Int -> IO [a]
randByteString size = do
  gen <- newGenIO :: IO SystemRandom
  either (error . show) (return . unpackByteString . fst) (genBytes size gen)

-- | Converts a ByteString to the list of its octet values.
unpackByteString :: (Num a, Eq a, Integral a) => BC.ByteString -> [a]
unpackByteString str = map fromIntegral (B.unpack str)

-- | Encodes 3 message bits as 2 trits.
binToTern :: (Num a, Eq a, Integral a) => [a] -> [a]
binToTern [0,0,0] = [0,0]
binToTern [0,0,1] = [0,1]
binToTern [0,1,0] = [0,-1]
binToTern [0,1,1] = [1,0]
binToTern [1,0,0] = [1,1]
binToTern [1,0,1] = [1,-1]
binToTern [1,1,0] = [-1,0]
binToTern [1,1,1] = [-1,1]
binToTern _ = error "Problem converting binary to trinary"

-- | Inverse of 'binToTern'; the pair @[-1,-1]@ has no preimage and raises an
-- error.
ternToBin :: (Num a, Eq a, Integral a) => [a] -> [a]
ternToBin [0,0] = [0,0,0]
ternToBin [0,1] = [0,0,1]
ternToBin [0,-1] = [0,1,0]
ternToBin [1,0] = [0,1,1]
ternToBin [1,1] = [1,0,0]
ternToBin [1,-1] = [1,0,1]
ternToBin [-1,0] = [1,1,0]
ternToBin [-1,1] = [1,1,1]
ternToBin _ = error " Problem converting trinary to binary"

-- | Pads with 0s so the length becomes a multiple of 3.
addPadding :: (Num a, Eq a, Integral a) => [a] -> [a]
addPadding m = m ++ replicate ((3 - length m `mod` 3) `mod` 3) 0

-- | Writes @b@ as @n + 1@ bits, most significant first:
-- @unpackByte 7 3 == [0,0,0,0,0,0,1,1]@.
unpackByte :: (Num a, Eq a, Integral a) => a -> a -> [a]
unpackByte n b
  | n < 0 = []
  | otherwise = (b `div` (2 ^ n)) : unpackByte (n - 1) (b `mod` 2 ^ n)

-- | Converts an octet to a list of 8 bits.
intToBits :: (Num a, Eq a, Integral a) => a -> [a]
intToBits = unpackByte 7

-- | Converts a value to a list of 11 bits (public-key coefficients, needed for
-- the blinding polynomial seed).
bigIntToBits :: (Num a, Eq a, Integral a) => a -> [a]
bigIntToBits = unpackByte 10

-- | Turns a list of octets into bits.
intsToBits :: (Num a, Eq a, Integral a) => [a] -> [a]
intsToBits = concatMap intToBits

-- | Reads a list of bits (most significant first) as a number:
-- @bitsToInt [0,0,0,0,0,0,1,1] == 3@.
bitsToInt :: (Num a, Eq a, Integral a) => [a] -> a
bitsToInt = go 0
  where
    go acc [] = acc
    go acc (bit : rest) = let acc' = 2 * acc + bit in acc' `seq` go acc' rest

-- | Generates a random polynomial with @n@ coefficients, exactly @pos@ of
-- which are 1 and exactly @neg@ of which are -1 (assuming @pos + neg <= n@).
--
-- Each position is decided in turn: with @k@ positions left, it gets 1 with
-- probability @pos'/k@ and -1 with probability @neg'/k@, where @pos'@ and
-- @neg'@ are the counts still to place. This makes every such polynomial
-- equally likely. Randomness comes from the cryptographic 'SystemRandom',
-- since the result is used as key material.
genRandPoly :: (Num a, Eq a, Integral a) => Int -> Int -> Int -> IO (Poly a)
genRandPoly n pos neg = do
  gen <- newGenIO :: IO SystemRandom
  case fill gen n pos neg [] of
    Left err -> error $ show err
    Right coeffs -> return $ toPoly coeffs
  where
    fill _ 0 _ _ acc = Right acc
    fill g k ones minusOnes acc = do
      (r, g') <- uniformBelow g k
      if r < ones
        then fill g' (k - 1) (ones - 1) minusOnes (1 : acc)
        else if r < ones + minusOnes
          then fill g' (k - 1) ones (minusOnes - 1) ((-1) : acc)
          else fill g' (k - 1) ones minusOnes (0 : acc)
    -- uniform value in [0, k) by rejection sampling on 32 random bits
    uniformBelow g k = do
      (bytes, g') <- genBytes 4 g
      let v = B.foldl' (\acc w -> acc * 256 + toInteger w) 0 bytes
          range = 2 ^ (32 :: Int) :: Integer
          bound = toInteger k
      if v < range - range `mod` bound
        then Right (fromInteger (v `mod` bound), g')
        else uniformBelow g' k

-- | Pads a coefficient list with zeros up to length @n@.
improperPolynomial :: (Num a, Eq a, Integral a) => Int -> [a] -> [a]
improperPolynomial n poly = poly ++ replicate (n - length poly) 0

-- | Pads the given list with zeros up to a multiple of 8 in length.
padInt8 :: (Num a, Eq a, Integral a) => [a] -> [a]
padInt8 lst = lst ++ replicate ((8 - (length lst `mod` 8)) `mod` 8) 0

-- | Packs the low 2 bits of each value (most significant first) into octets.
-- A final partial octet is padded with zero bits.
toOctets :: (Num a, Eq a, Integral a) => [a] -> [a]
toOctets lst =
  let bits = concatMap (drop 6 . unpackByte 7) lst
  in map (bitsToInt . padInt8) (chunksOf 8 bits)

{- Parameter Sets -}

-- | Returns the parameter list for the given bit level (112, 128, 192 or 256).
-- Positions are read by the getters below:
-- N, p, q, df, dg, lLen, db, maxMsgLenBytes, bufferLenBits, bufferLenTrits,
-- dm0, hash selector (1 = SHA-1, otherwise SHA-256), dr, c, minCallsR,
-- minCallsMask, three OID octets, pkLen, security level.
genParams :: (Num a, Eq a, Integral a) => a -> [Int]
genParams bit_level
  | bit_level == 112 = [401,3,2048,113,133,1,112,60,600,400,113,1,113,11,32,0,0,2,4,114,112]
  | bit_level == 128 = [449,3,2048,134,149,1,128,67,672,448,134,1,134,9,31,9,0,3,3,128,128]
  | bit_level == 192 = [677,3,2048,157,225,1,192,101,1008,676,157,256,157,11,27,9,0,5,3,192,192]
  | bit_level == 256 = [1087,3,2048,120,362,1,256,170,1624,1086,120,256,120,13,25,14,0,6,3,256,256]
  | otherwise = error "BitLevel must be 112, 128, 192, 256"

-- | Accessors for the parameter list produced by 'genParams'.
getN :: [Int] -> Int
getN = head

getP :: Num b => [Int] -> b
getP lst = fromIntegral $ lst !! 1

getQ :: Num b => [Int] -> b
getQ lst = fromIntegral $ lst !! 2

getDf :: [Int] -> Int
getDf lst = lst !! 3

getDg :: [Int] -> Int
getDg lst = lst !! 4

getLLen :: [Int] -> Int
getLLen lst = lst !! 5

getDb :: [Int] -> Int
getDb lst = lst !! 6

getMaxMsgLenBytes :: [Int] -> Int
getMaxMsgLenBytes lst = lst !! 7

getBufferLenBits :: [Int] -> Int
getBufferLenBits lst = lst !! 8

getBufferLenTrits :: [Int] -> Int
getBufferLenTrits lst = lst !! 9

getDm0 :: [Int] -> Int
getDm0 lst = lst !! 10

getSHA :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a]
getSHA lst
  | lst !! 11 == 1 = sha1Octets
  | otherwise = sha256Octets

-- | Hash output length in octets.
getHLen :: [Int] -> Int
getHLen lst
  | lst !! 11 == 1 = 20
  | otherwise = 32

getDr :: [Int] -> Int
getDr lst = lst !! 12

getC :: [Int] -> Int
getC lst = lst !! 13

getMinCallsR :: Num b => [Int] -> b
getMinCallsR lst = fromIntegral $ lst !! 14

getMinCallsMask :: Num b => [Int] -> b
getMinCallsMask lst = fromIntegral $ lst !! 15

getOID :: [Int] -> [Int]
getOID lst = [lst !! 16, lst !! 17, lst !! 18]

getPkLen :: [Int] -> Int
getPkLen lst = lst !! 19

getLvl :: [Int] -> Int
getLvl lst = lst !! 20

{- External Functions -}

-- | Generates a key-pair with the EES401EP1 Parameter Set
keyGen112
  :: (Num a, Eq a, Integral a)
  => IO ([a], [a])
  -- ^ A tuple representing (PublicKey, PrivateKey) where PrivateKey = 1 + pf,
  -- per Enhancement #2 at
  -- https://www.securityinnovation.com/uploads/Crypto/NTRU%20Enhancements%201.pdf
keyGen112 = generateKeyPair (genParams 112)

-- | Generates a key-pair with the EES449EP1 Parameter Set
keyGen128 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen128 = generateKeyPair (genParams 128)

-- | Generates a key-pair with the EES677EP1 Parameter Set
keyGen192 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen192 = generateKeyPair (genParams 192)
-- | Generates a key-pair with the EES1087EP2 Parameter Set.
keyGen256
  :: (Num a, Eq a, Integral a)
  => IO ([a], [a]) -- ^ (PublicKey, PrivateKey)
keyGen256 = generateKeyPair (genParams 256)

-- | Encrypts a message with the EES401EP1 Parameter Set.
encrypt112
  :: (Num a, Eq a, Integral a)
  => [a]    -- ^ A list of ASCII values representing the message
  -> [a]    -- ^ A list of numbers representing the public key
  -> IO [a] -- ^ A list of numbers representing the ciphertext
encrypt112 msg pubKey = encrypt (genParams 112) msg pubKey

-- | Encrypts a message with the EES449EP1 Parameter Set.
encrypt128
  :: (Num a, Eq a, Integral a)
  => [a]    -- ^ message octets
  -> [a]    -- ^ public key
  -> IO [a] -- ^ ciphertext
encrypt128 msg pubKey = encrypt (genParams 128) msg pubKey

-- | Encrypts a message with the EES677EP1 Parameter Set.
encrypt192
  :: (Num a, Eq a, Integral a)
  => [a]    -- ^ message octets
  -> [a]    -- ^ public key
  -> IO [a] -- ^ ciphertext
encrypt192 msg pubKey = encrypt (genParams 192) msg pubKey

-- | Encrypts a message with the EES1087EP2 Parameter Set.
encrypt256
  :: (Num a, Eq a, Integral a)
  => [a]    -- ^ message octets
  -> [a]    -- ^ public key
  -> IO [a] -- ^ ciphertext
encrypt256 msg pubKey = encrypt (genParams 256) msg pubKey

-- | Decrypts and verifies a ciphertext with the EES401EP1 Parameter Set.
decrypt112
  :: (Num a, Eq a, Integral a)
  => [a] -- ^ A list of numbers representing the private key
  -> [a] -- ^ A list of numbers representing the public key
  -> [a] -- ^ A list of numbers representing the ciphertext
  -> [a] -- ^ A list of numbers representing the original message
decrypt112 privKey pubKey cipher = decrypt (genParams 112) privKey pubKey cipher

-- | Decrypts and verifies a ciphertext with the EES449EP1 Parameter Set.
decrypt128
  :: (Num a, Eq a, Integral a)
  => [a] -- ^ private key
  -> [a] -- ^ public key
  -> [a] -- ^ ciphertext
  -> [a] -- ^ original message
decrypt128 privKey pubKey cipher = decrypt (genParams 128) privKey pubKey cipher

-- | Decrypts and verifies a ciphertext with the EES677EP1 Parameter Set.
decrypt192
  :: (Num a, Eq a, Integral a)
  => [a] -- ^ private key
  -> [a] -- ^ public key
  -> [a] -- ^ ciphertext
  -> [a] -- ^ original message
decrypt192 privKey pubKey cipher = decrypt (genParams 192) privKey pubKey cipher

-- | Decrypts and verifies a ciphertext with the EES1087EP2 Parameter Set.
decrypt256
  :: (Num a, Eq a, Integral a)
  => [a] -- ^ private key
  -> [a] -- ^ public key
  -> [a] -- ^ ciphertext
  -> [a] -- ^ original message
decrypt256 privKey pubKey cipher = decrypt (genParams 256) privKey pubKey cipher