module Math.NTRU (keyGen, encrypt, decrypt, genParams, ParamSet(..)) where
import Data.Digest.Pure.SHA
import Data.List.Split
import Data.Sequence as Seq (index, update, fromList, Seq)
import Data.Foldable as L (toList)
import Crypto.Random
import System.Random
import Math.Polynomial
import Math.NumberTheory.Moduli
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
-- | Coefficient list of a polynomial in little-endian ('LE') order.
fromPoly :: Poly Integer -> [Integer]
fromPoly p = polyCoeffs LE p
-- | Build a polynomial from a little-endian ('LE') coefficient list.
toPoly :: [Integer] -> Poly Integer
toPoly coeffs = poly LE coeffs
-- | Coefficient at position @i@ of the little-endian coefficient list of @p@.
--
-- Partial: uses '!!', so an index outside the coefficient list is an error.
polyCoef :: Poly Integer -> Int -> Integer
polyCoef p i = coeffs !! i
  where
    coeffs = fromPoly p
-- | Ring operations on polynomials so that '+', '*', '-' and 'negate' can be
-- used directly on 'Poly' values.  '-' comes from the class default
-- (@a - b = a + negate b@).
--
-- 'abs' and 'signum' have no definition here and raise a descriptive error
-- when used.  'fromInteger' embeds an integer as a constant polynomial.
instance (Num a, Eq a) => Num (Poly a) where
  f + g = addPoly f g
  f * g = multPoly f g
  negate = negatePoly
  abs = error "Math.NTRU: abs is not defined for Poly"
  signum = error "Math.NTRU: signum is not defined for Poly"
  fromInteger k = poly LE [fromInteger k]
-- | Fold a polynomial back onto its first @n@ coefficients: the coefficient
-- list is split at position @n@ and the two halves are added as polynomials.
-- Only a single fold is performed.
reduceDegree :: Int -> Poly Integer -> Poly Integer
reduceDegree n f = toPoly low + toPoly high
  where
    (low, high) = splitAt n (fromPoly f)
-- | Reduce every coefficient of a polynomial with @`mod` q@.
polyMod :: Integer -> Poly Integer -> Poly Integer
polyMod q = toPoly . map (`mod` q) . fromPoly
-- | Reduce every coefficient modulo @q@ into a centred range: a residue
-- @r = c `mod` q@ is kept when @r <= q `div` 2@ and replaced by @r - q@
-- otherwise.
polyModInterval :: Integer -> Poly Integer -> Poly Integer
polyModInterval q f = toPoly $ map (centre . (`mod` q)) (fromPoly f)
  where
    -- Move residues above q/2 to their negative representative.
    centre r
      | r <= q `div` 2 = r
      | otherwise = r - q
-- | Reduce every coefficient of a polynomial with @`mod` q@.
--
-- This is the same operation as 'polyMod'.  The earlier body converted the
-- coefficients with 'fromIntegral' before and after the reduction, but both
-- conversions were @Integer -> Integer@ and therefore no-ops.
polyBigMod :: Integer -> Poly Integer -> Poly Integer
polyBigMod q p = polyMod q p
-- | The monomial obtained by raising the library's indeterminate 'x' to the
-- given power.
xPow :: Int -> Poly Integer
xPow k = powPoly x k
-- | Polynomial long division of @a@ by @b@ with coefficients reduced modulo
-- @p@.  Returns @(quotient, remainder)@, both reduced with 'polyMod'.
--
-- The leading coefficient of @b@ must be invertible modulo @p@; otherwise
-- 'inverseMod' raises an error.
divPolyMod :: Integer -> Poly Integer -> Poly Integer -> (Poly Integer, Poly Integer)
divPolyMod p a b = divLoop zero a
  where
    n = polyDegree b
    -- Inverse of the divisor's leading coefficient, used to cancel the
    -- leading term of the running remainder.
    u = inverseMod (polyCoef b n) p
    divLoop q r
      | d < n = (polyMod p q, polyMod p r)
      | otherwise = divLoop q' r'
      where
        d = polyDegree r
        -- Monomial that, multiplied by b, matches r's leading term mod p.
        v = scalePoly (u * polyCoef r d) (xPow (d - n))
        r' = polyMod p $ r - (v * b)
        q' = polyMod p $ q + v
-- | Extended Euclidean algorithm on polynomials with coefficients modulo @p@.
--
-- Returns @(d, u)@: the loop starts from @(u, d, v1, v3) = (1, a, 0, b)@ and
-- stops when @v3@ is zero.  'findInvertible' treats @u@ as the inverse of @a@
-- modulo @b@ when @d@ equals 'one'.
extendedEuclidean :: Integer -> Poly Integer -> Poly Integer -> (Poly Integer, Poly Integer)
extendedEuclidean p a b = go one a zero b
  where
    go u d v1 v3
      | polyIsZero v3 = (d, u)
      | otherwise =
          let (q, t3) = divPolyMod p d v3
              t1 = polyMod p $ u - q * v1
          in go v1 v3 t1 t3
-- | Draw random polynomials of the form @1 + p * F@ until one is invertible
-- modulo 2 in the ring defined by @x^N - 1@.
--
-- Returns the polynomial together with the second component of
-- 'extendedEuclidean', which is used as its inverse modulo 2.  Retries
-- recursively (with fresh randomness) when the first component is not 'one'.
findInvertible :: ParamSet -> IO (Poly Integer, Poly Integer)
findInvertible params = do
  let n = getN params
      df = getDf params
  a' <- genRandPoly n df df
  let a = scalePoly (getP params) a' + one
      b = xPow n - one
      (d, u) = extendedEuclidean 2 a b
  if d == one then return (a, u) else findInvertible params
-- | Lift an inverse of @a@ modulo 2 to an inverse modulo a higher power of 2
-- by repeating the step @b <- 2*b - a*b*b@, with products folded by
-- 'reduceDegree'.
--
-- Arguments: the polynomial @a@, its inverse @b@ modulo 2, the fold length
-- passed to 'reduceDegree', and a final exponent @q@; the result is reduced
-- modulo @2 ^ q@.
--
-- The loop starts with modulus exponent 2 and counter 11; each round doubles
-- the exponent and halves the counter, stopping when the counter reaches 0.
--
-- NOTE(review): 'keyGen' passes the modulus itself (e.g. 2048) as @q@, so the
-- final reduction is modulo @2 ^ 2048@, which looks unintended.  It is kept
-- as-is because 'keyGen' reduces modulo @q@ afterwards — confirm.
inverseLift :: Poly Integer -> Poly Integer -> Int -> Integer -> Poly Integer
inverseLift a b deg = inverseLift' a b deg (2 :: Integer) (11 :: Integer)
  where
    inverseLift' a0 b0 deg0 n e q
      | e == 0 = polyMod (2 ^ q) b0
      | otherwise =
          let correction = reduceDegree deg0 $! a0 * (reduceDegree deg0 $! (b0 * b0))
              b' = polyBigMod (2 ^ n) $ scalePoly 2 b0 - correction
          in inverseLift' a0 b' deg0 (2 * n) (e `div` 2) q
-- | Generate a key pair for the given parameter set.
--
-- Returns @(publicKey, privateKey)@ as little-endian coefficient lists.  The
-- private key is the polynomial @f@ from 'findInvertible'.  The public key is
-- @p * fq * g@ folded by 'reduceDegree' and reduced modulo @q@, where @fq@ is
-- the lifted inverse of @f@ and @g@ is a random polynomial requested with
-- @dg@ ones and @dg - 1@ minus-ones.
keyGen :: ParamSet
       -> IO ([Integer], [Integer])
keyGen params = do
  let n = getN params
      dg = getDg params
      q = getQ params
  (f, u) <- findInvertible params
  let fq = inverseLift f u n (fromIntegral q)
  g <- genRandPoly n dg (dg - 1)
  let pk = polyMod q $! reduceDegree n $! scalePoly (getP params) $! fq * g
  return (fromPoly pk, fromPoly f)
-- | Assemble the seed data fed to 'bpgm': the parameter-set OID, the message
-- @msg@, the random octets @b@, and a truncated octet encoding of the public
-- key @h@, in that order.
--
-- The public key is expanded with 'bigIntToBits', cut to the largest multiple
-- of 8 bits not exceeding 'getPkLen', and repacked into octets.
genSData :: [Integer] -> [Integer] -> [Integer] -> ParamSet -> [Integer]
genSData h msg b params = oid ++ msg ++ b ++ hTrunc
  where
    pkLen = getPkLen params
    bhTrunc = take (pkLen - (pkLen `mod` 8)) (concatMap bigIntToBits h)
    hTrunc = map bitsToInt (chunksOf 8 bhTrunc)
    oid = map fromIntegral (getOID params)
-- | Derive a coefficient list of length 'getN' from a seed: positions chosen
-- by 'igf' are set to @1@ ('getDr' of them, counting the first index) and
-- then 'getDr' further positions are set to @-1@; all others stay @0@.
--
-- NOTE(review): the second 'rlooper' call restarts from the same generator
-- state @s@ as the first instead of continuing from where the first ended.
-- It is left unchanged because 'encrypt' and 'decrypt' both rely on this
-- exact output — confirm whether it is intended.
bpgm :: [Integer] -> ParamSet -> [Integer]
bpgm seed params =
  let (i, s) = igf ([], [], 0) seed params
      r = Seq.update i 1 $ Seq.fromList $ replicate (getN params) 0
      t = getDr params
      r' = rlooper s 1 r (t - 1) params
  in L.toList $ rlooper s (-1) r' t params
-- | Write @val@ into @t@ still-zero positions of @r@, taking candidate
-- positions from 'igf'.  A candidate whose slot is already non-zero is
-- skipped without decrementing @t@.
rlooper :: ([Integer], [Integer], Integer) -> Integer -> Seq.Seq Integer -> Int -> ParamSet -> Seq.Seq Integer
rlooper _ _ r 0 _ = r
rlooper s val r t params
  | Seq.index r i == 0 = rlooper s' val (Seq.update i val r) (t - 1) params
  | otherwise = rlooper s' val r t params
  where
    -- An empty seed tells igf to continue from the existing state.
    (i, s') = igf s [] params
-- | Produce the next index (reduced modulo 'getN') and the updated generator
-- state.  A non-empty @seed@ re-initialises the state via
-- 'extractVariables'; an empty one continues from @state@.
igf :: ([Integer], [Integer], Integer) -> [Integer] -> ParamSet -> (Int, ([Integer], [Integer], Integer))
igf state seed params = (rawIndex `mod` getN params, (z, buf', counter'))
  where
    (z, buf, counter) = extractVariables state seed params
    (rawIndex, buf', counter') = genIndex counter buf z params
-- | Choose the generator state: keep @state@ when the seed is empty,
-- otherwise build a fresh state from the seed with 'igfinit'.
extractVariables :: ([Integer], [Integer], Integer) -> [Integer] -> ParamSet -> ([Integer], [Integer], Integer)
extractVariables state seed params
  | null seed = state
  | otherwise = igfinit seed params
-- | Initial generator state for a seed: the hashed seed @z@, a bit buffer
-- produced by 'getMinCallsR' hash calls (counters 0 upwards), and the
-- counter value reached.
igfinit :: [Integer] -> ParamSet -> ([Integer], [Integer], Integer)
igfinit seed params = (z, buildM 0 minCalls z hash [], minCalls)
  where
    minCalls = getMinCallsR params
    hash = getSHA params
    z = hash seed
-- | Take 'getC' bits from the buffer and interpret them as an integer,
-- retrying with the remaining buffer while the value is at or above
-- @2^c - (2^c `mod` N)@ (rejection sampling so that the later @`mod` N@ in
-- 'igf' is unbiased).  Returns the accepted value, the remaining buffer, and
-- the hash counter.
--
-- When the buffer holds fewer than @c@ bits, more are generated with 'buildM'.
--
-- NOTE(review): the bits are taken from @buf ++ m@, where @m@ already starts
-- with @buf@, so the buffer contents appear twice.  This looks unintended
-- but changing it would change every generated index — confirm before
-- touching.
genIndex :: Integer -> [Integer] -> [Integer] -> ParamSet -> (Int, [Integer], Integer)
genIndex counter buf z params
  | i >= (2 ^ c - (2 ^ c `mod` n)) = genIndex counter' buf' z params
  | otherwise = (i, buf', counter')
  where
    remLen = length buf
    c = getC params
    n = getN params
    shaFn = getSHA params
    hLen = getHLen params
    -- Number of bits still missing from the buffer.
    tmpLen = c - remLen
    cThreshold = counter + fromIntegral (ceiling (fromIntegral tmpLen / fromIntegral hLen))
    (m, counter')
      | remLen >= c = (buf, counter)
      | otherwise = (buildM counter cThreshold z shaFn buf, cThreshold)
    (b, buf') = splitAt c (buf ++ m)
    i = fromIntegral $ bitsToInt b
-- | Extend a bit buffer by hashing @z@ followed by the encoded counter
-- (@i2osp count 3@) for every counter value from @count@ up to, but not
-- including, @cThreshold@.  Each digest is appended as bits via 'intsToBits'.
buildM :: Integer -> Integer -> [Integer] -> ([Integer]->[Integer]) -> [Integer] -> [Integer]
buildM count cThreshold z shaFn buf
  | count < cThreshold = buildM (count + 1) cThreshold z shaFn extended
  | otherwise = buf
  where
    digest = shaFn (z ++ i2osp count 3)
    extended = buf ++ intsToBits digest
-- | Encode @i@ as @n@ zero octets followed by @i@ itself, giving @n + 1@
-- elements.  @i@ is not split into octets.
--
-- A non-positive @n@ yields just @[i]@; previously a negative @n@ never
-- terminated.
i2osp :: Integer -> Integer -> [Integer]
i2osp i n
  | n <= 0 = [i]
  | otherwise = 0 : i2osp i (n - 1)
-- | Convert a lazy 'BL.ByteString' to a strict one by concatenating its
-- chunks.
bToStrict :: BL.ByteString -> B.ByteString
bToStrict lazyBytes = B.concat (BL.toChunks lazyBytes)
-- | SHA-1 over a list of octet values, returning the digest as a list of
-- octet values.  Each input element is converted to a byte with
-- 'fromIntegral'.
sha1Octets :: [Integer] -> [Integer]
sha1Octets input = map fromIntegral digestBytes
  where
    message = BL.pack (map fromIntegral input)
    digestBytes = B.unpack (bToStrict (bytestringDigest (sha1 message)))
-- | SHA-256 over a list of octet values, returning the digest as a list of
-- octet values.  Each input element is converted to a byte with
-- 'fromIntegral'.
sha256Octets :: [Integer] -> [Integer]
sha256Octets input = map fromIntegral digestBytes
  where
    message = BL.pack (map fromIntegral input)
    digestBytes = B.unpack (bToStrict (bytestringDigest (sha256 message)))
-- | Mask generation: hash the seed, build an octet buffer with
-- 'getMinCallsR' hash calls, convert it to trits with 'formatI', top it up
-- through 'finishI', and return the first 'getN' trits.
mgf :: [Integer] -> ParamSet -> [Integer]
mgf seed params = take n (finishI trits n minCalls z hash)
  where
    n = getN params
    hash = getSHA params
    minCalls = getMinCallsR params
    z = hash seed
    trits = formatI (buildBuffer 0 minCalls z hash [])
-- | Extend an octet buffer by hashing @z@ followed by the encoded counter
-- (@i2osp counter 3@) for every counter value from @counter@ up to, but not
-- including, @minCallsR@.  Each digest is appended as octets.
buildBuffer :: Integer -> Integer -> [Integer] -> ([Integer]->[Integer]) -> [Integer] -> [Integer]
buildBuffer counter minCallsR z shaFn buffer
  | counter < minCallsR = buildBuffer (counter + 1) minCallsR z shaFn (buffer ++ digest)
  | otherwise = buffer
  where
    digest = shaFn (z ++ i2osp counter 3)
-- | The @n@ lowest base-3 digits of @o@, least significant first.
toTrits :: Integer -> Integer -> [Integer]
toTrits n o
  | n == 0 = []
  | otherwise = t : toTrits (n - 1) ((o - t) `div` 3)
  where
    t = o `mod` 3
-- | Top up a trit list until it holds at least @n@ entries.  Each round
-- hashes @z@ with the next counter value, converts the digest with 'formatI',
-- and appends the new trits to those already collected.
--
-- Previously each round replaced the collected trits with the trits of a
-- single digest.  A single digest yields at most 5 trits per octet, far fewer
-- than any supported @n@, so the loop could not terminate once entered.
-- Results are unchanged whenever the initial list is already long enough.
finishI :: [Integer] -> Int -> Integer -> [Integer] -> ([Integer] -> [Integer]) -> [Integer]
finishI i n counter z shaFn
  | length i >= n = i
  | otherwise =
      let buf = buildBuffer counter (counter + 1) z shaFn []
      in finishI (i ++ formatI buf) n (counter + 1) z shaFn
-- | Convert octets to trits: octets of 243 or more are dropped, and every
-- remaining octet contributes its 5 base-3 digits via 'toTrits'.
formatI :: [Integer] -> [Integer]
formatI buf = concat [toTrits 5 octet | octet <- buf, octet < 243]
-- | Encrypt a message (list of octets) under the public key @h@ (coefficient
-- list).  Returns the ciphertext as a coefficient list.
--
-- Raises @error "message too long"@ when the message exceeds
-- 'getMaxMsgLenBytes'.
--
-- The plaintext block is: random octets @b@, one length octet, the message,
-- and zero padding up to the maximum message length.  It is expanded to
-- bits, padded to a multiple of 3, and mapped to trits with 'binToTern'.
-- The blinding value is derived from @b@, the message and @h@ via
-- 'genSData' and 'bpgm'; the trits are masked with 'mgf' seeded from
-- @r' mod 4@.
encrypt :: ParamSet
        -> [Integer]
        -> [Integer]
        -> IO [Integer]
encrypt params msg h
  | l > maxLength = error "message too long"
  | otherwise = do
      b <- randByteString bLen
      let -- Zero octets filling the message up to the maximum length.
          p0 = replicate (maxLength - l) 0
          m = b ++ [fromIntegral l] ++ msg ++ p0
          mBin = addPadding $ intsToBits m
          mTrin = concatMap binToTern $ chunksOf 3 mBin
          sData = genSData h msg b params
          r = bpgm sData params
          r' = polyMod q $ reduceDegree n $ toPoly r * toPoly h
          r4 = polyMod 4 r'
          or4 = toOctets $ fromPoly r4
          mask = mgf or4 params
          m' = polyModInterval p $ toPoly mask + toPoly mTrin
          e = polyMod q $ r' + m'
      return $ fromPoly e
  where
    l = length msg
    maxLength = getMaxMsgLenBytes params
    bLen = getDb params `div` 8
    n = getN params
    q = getQ params
    p = getP params
-- | Decrypt a ciphertext @e@ with private key @f@ and public key @h@ (all
-- coefficient lists).  Returns 'Just' the message octets, or 'Nothing' when
-- validation fails.
--
-- Validation re-derives the blinding value from the recovered message and
-- random octets and compares it with the one implied by the ciphertext; it
-- also requires every octet after the message to be zero.  A recovered block
-- too short to contain the length field yields 'Nothing' as well.
--
-- 'ternToBin' still raises an error on a trit pair it does not recognise.
decrypt :: ParamSet
        -> [Integer]
        -> [Integer]
        -> [Integer]
        -> Maybe [Integer]
decrypt params f h e =
  case splitAt (getLLen params) rest of
    ([cl], rest') ->
      let (cm, rest'') = splitAt (fromIntegral cl) rest'
          sData = genSData h cm cb params
          cr = bpgm sData params
          cR' = polyMod q $ reduceDegree n $ toPoly cr * toPoly h
          validR = cR' == cR
          validRemainder = all (== 0) rest''
      in checkValid cm validR validRemainder
    -- The length field is missing or is not a single octet.
    _ -> Nothing
  where
    n = getN params
    p = getP params
    q = getQ params
    bLen = getDb params `div` 8
    ci = polyMod p $ polyModInterval q $ reduceDegree n $ toPoly f * toPoly e
    -- Blinding value implied by the ciphertext.
    cR = polyMod q $ toPoly e - polyModInterval p ci
    cR4 = polyMod 4 cR
    coR4 = toOctets $ fromPoly cR4
    cMask = polyMod p $ toPoly $ mgf coR4 params
    cMTrin = polyModInterval p $ ci - cMask
    cMTrin' = improperPolynomial n $ fromPoly cMTrin
    tritCount = length cMTrin'
    -- Drop a trailing odd trit, then any bits that do not fill an octet.
    cMBin = concatMap ternToBin $ chunksOf 2 $ take (tritCount - (tritCount `mod` 2)) cMTrin'
    bitCount = length cMBin
    cM = map bitsToInt $ chunksOf 8 $ take (bitCount - (bitCount `mod` 8)) cMBin
    (cb, rest) = splitAt bLen cM
-- | Return the message only when both validity flags are 'True'.
--
-- The third argument is inspected before the second.
checkValid :: [Integer] -> Bool -> Bool -> Maybe [Integer]
checkValid m validR validRemainder
  | validRemainder && validR = Just m
  | otherwise = Nothing
-- | Modular inverse of @x@ modulo @y@ via 'invertMod'.
--
-- Raises @error "Could not calculate inverseMod"@ when 'invertMod' returns
-- 'Nothing'.
inverseMod :: Integer -> Integer -> Integer
inverseMod x y =
  maybe
    (error "Could not calculate inverseMod")
    fromIntegral
    (invertMod (fromIntegral x) (fromIntegral y))
-- | Generate @size@ random octets from a freshly created 'SystemRandom'
-- generator.  A generator error is turned into 'error' with its 'show' text.
randByteString :: Int -> IO [Integer]
randByteString size = do
  gen <- newGenIO :: IO SystemRandom
  either
    (error . show)
    (return . unpackByteString . fst)
    (genBytes size gen)
-- | The bytes of a strict byte string as a list of 'Integer' values.
unpackByteString :: BC.ByteString -> [Integer]
unpackByteString = map fromIntegral . B.unpack
-- | Map a group of three bits to a pair of trits in @{-1, 0, 1}@.
--
-- The eight bit patterns map to eight distinct pairs; the pair @[-1,-1]@ is
-- never produced.  Any input that is not exactly three bits raises an error.
-- 'ternToBin' is the inverse mapping.
binToTern :: [Integer] -> [Integer]
binToTern [0,0,0] = [0,0]
binToTern [0,0,1] = [0,1]
binToTern [0,1,0] = [0,-1]
binToTern [0,1,1] = [1,0]
binToTern [1,0,0] = [1,1]
binToTern [1,0,1] = [1,-1]
binToTern [1,1,0] = [-1,0]
binToTern [1,1,1] = [-1,1]
binToTern _ = error "Problem converting binary to trinary"
-- | Map a pair of trits in @{-1, 0, 1}@ back to the three bits that
-- 'binToTern' encodes as that pair.
--
-- The pair @[-1,-1]@ and any input that is not exactly two recognised trits
-- raise an error.
ternToBin :: [Integer] -> [Integer]
ternToBin [0,0] = [0,0,0]
ternToBin [0,1] = [0,0,1]
ternToBin [0,-1] = [0,1,0]
ternToBin [1,0] = [0,1,1]
ternToBin [1,1] = [1,0,0]
ternToBin [1,-1] = [1,0,1]
ternToBin [-1,0] = [1,1,0]
ternToBin [-1,1] = [1,1,1]
ternToBin _ = error " Problem converting trinary to binary"
-- | Append zeros so that the length of the list becomes a multiple of 3.
addPadding :: [Integer] -> [Integer]
addPadding m = m ++ replicate padLen 0
  where
    padLen = (3 - (length m `mod` 3)) `mod` 3
-- | Binary expansion of @b@ from bit position @n@ down to 0, most
-- significant first (@n + 1@ entries).  A negative @n@ gives the empty list.
--
-- The first entry is @b `div` 2^n@ without further reduction, so it exceeds
-- 1 when @b >= 2^(n+1)@.
unpackByte :: Integer -> Integer -> [Integer]
unpackByte n b
  | n < 0 = []
  | otherwise = (b `div` (2 ^ n)) : unpackByte (n - 1) (b `mod` 2 ^ n)
-- | Expand a value into 8 bits (positions 7 down to 0) with 'unpackByte'.
intToBits :: Integer -> [Integer]
intToBits octet = unpackByte 7 octet
-- | Expand a value into 11 bits (positions 10 down to 0) with 'unpackByte'.
bigIntToBits :: Integer -> [Integer]
bigIntToBits value = unpackByte 10 value
-- | Expand every value in the list into 8 bits and concatenate the results.
intsToBits :: [Integer] -> [Integer]
intsToBits octets = concatMap intToBits octets
-- | Interpret a list of digits as a base-2 number, most significant first.
-- The empty list gives 0.
bitsToInt :: [Integer] -> Integer
bitsToInt = go 0
  where
    -- Horner evaluation with a strict accumulator.
    go acc [] = acc
    go acc (d : ds) =
      let acc' = 2 * acc + d
      in acc' `seq` go acc' ds
-- | Random polynomial with @n@ coefficient slots containing @pos@ ones and
-- @neg@ minus-ones (when @pos + neg <= n@); every other slot is zero.
--
-- Each step draws a number below the count of remaining slots and compares
-- it with the remaining quotas, so the quotas are used up exactly by the
-- time the slots run out.
--
-- NOTE(review): randomness comes from 'randomIO', while 'randByteString'
-- uses a 'SystemRandom' generator.  Confirm whether 'randomIO' is acceptable
-- for key material.
genRandPoly :: Int -> Int -> Int -> IO (Poly Integer)
genRandPoly n pos neg = do
  coeffs <- setRandValues [] n pos neg
  return $ toPoly coeffs
  where
    setRandValues lst slots onesLeft minusLeft
      | slots == 0 = return lst
      | otherwise = do
          randVal <- randomIO :: IO Int
          let randInRange = randVal `mod` slots
          if randInRange < onesLeft
            then setRandValues (1 : lst) (slots - 1) (onesLeft - 1) minusLeft
            else
              if randInRange < (onesLeft + minusLeft)
                then setRandValues ((-1) : lst) (slots - 1) onesLeft (minusLeft - 1)
                else setRandValues (0 : lst) (slots - 1) onesLeft minusLeft
-- | Pad a coefficient list with trailing zeros up to length @n@.  A list that
-- already has @n@ or more entries is returned unchanged.
improperPolynomial :: Int -> [Integer] -> [Integer]
improperPolynomial n coeffs = coeffs ++ replicate (n - length coeffs) 0
-- | Append zeros so that the length of the list becomes a multiple of 8.
padInt8 :: [Integer] -> [Integer]
padInt8 lst = lst ++ replicate ((8 - (length lst `mod` 8)) `mod` 8) 0
-- | Pack a list of values into octets using two bits per value: each value
-- contributes the last two entries of its 8-bit expansion, the bit stream is
-- cut into groups of 8, and a short final group is zero-padded on the right.
toOctets :: [Integer] -> [Integer]
toOctets lst = map (bitsToInt . padInt8) (chunksOf 8 bitPairs)
  where
    bitPairs = concatMap lastTwoBits lst
    lastTwoBits = reverse . take 2 . reverse . unpackByte 7
-- | Look up a parameter set by name.
--
-- Raises @error "Unsupported Parameter Set"@ for any name not listed below.
genParams :: String
          -> ParamSet
genParams bit_level = case bit_level of
  "EES401EP1" -> ParamSet {getN = 401, getP = 3, getQ = 2048, getDf = 113, getDg = 133, getLLen = 1, getDb = 112, getMaxMsgLenBytes = 60, getBufferLenBits = 600, getBufferLenTrits = 400, getDm0 = 113, getShaLvl = 1, getDr = 113, getC = 11, getMinCallsR = 32, getMinCallsMask = 9, getOID = [0,2,4], getPkLen = 114, getBitLvl = 112}
  "EES449EP1" -> ParamSet {getN = 449, getP = 3, getQ = 2048, getDf = 134, getDg = 149, getLLen = 1, getDb = 128, getMaxMsgLenBytes = 67, getBufferLenBits = 672, getBufferLenTrits = 448, getDm0 = 134, getShaLvl = 1, getDr = 134, getC = 9, getMinCallsR = 31, getMinCallsMask = 9, getOID = [0,3,3], getPkLen = 128, getBitLvl = 128}
  "EES677EP1" -> ParamSet {getN = 677, getP = 3, getQ = 2048, getDf = 157, getDg = 225, getLLen = 1, getDb = 192, getMaxMsgLenBytes = 101, getBufferLenBits = 1008, getBufferLenTrits = 676, getDm0 = 157, getShaLvl = 256, getDr = 157, getC = 11, getMinCallsR = 27, getMinCallsMask = 9, getOID = [0,5,3], getPkLen = 192, getBitLvl = 192}
  "EES1087EP2" -> ParamSet {getN = 1087, getP = 3, getQ = 2048, getDf = 120, getDg = 362, getLLen = 1, getDb = 256, getMaxMsgLenBytes = 170, getBufferLenBits = 1624, getBufferLenTrits = 1086, getDm0 = 120, getShaLvl = 256, getDr = 120, getC = 13, getMinCallsR = 25, getMinCallsMask = 14, getOID = [0,6,3], getPkLen = 256, getBitLvl = 256}
  "EES541EP1" -> ParamSet {getN = 541, getP = 3, getQ = 2048, getDf = 49, getDg = 180, getLLen = 1, getDb = 112, getMaxMsgLenBytes = 86, getBufferLenBits = 808, getBufferLenTrits = 540, getDm0 = 49, getShaLvl = 1, getDr = 49, getC = 12, getMinCallsR = 15, getMinCallsMask = 11, getOID = [0,2,5], getPkLen = 112, getBitLvl = 112}
  "EES613EP1" -> ParamSet {getN = 613, getP = 3, getQ = 2048, getDf = 55, getDg = 204, getLLen = 1, getDb = 128, getMaxMsgLenBytes = 97, getBufferLenBits = 912, getBufferLenTrits = 612, getDm0 = 55, getShaLvl = 1, getDr = 55, getC = 11, getMinCallsR = 16, getMinCallsMask = 13, getOID = [0,3,4], getPkLen = 128, getBitLvl = 128}
  "EES887EP1" -> ParamSet {getN = 887, getP = 3, getQ = 2048, getDf = 81, getDg = 295, getLLen = 1, getDb = 192, getMaxMsgLenBytes = 141, getBufferLenBits = 1328, getBufferLenTrits = 886, getDm0 = 81, getShaLvl = 256, getDr = 81, getC = 10, getMinCallsR = 13, getMinCallsMask = 12, getOID = [0,5,4], getPkLen = 192, getBitLvl = 192}
  "EES1171EP1" -> ParamSet {getN = 1171, getP = 3, getQ = 2048, getDf = 106, getDg = 390, getLLen = 1, getDb = 256, getMaxMsgLenBytes = 186, getBufferLenBits = 1752, getBufferLenTrits = 1170, getDm0 = 106, getShaLvl = 256, getDr = 106, getC = 10, getMinCallsR = 20, getMinCallsMask = 15, getOID = [0,6,4], getPkLen = 256, getBitLvl = 256}
  "EES659EP1" -> ParamSet {getN = 659, getP = 3, getQ = 2048, getDf = 38, getDg = 219, getLLen = 1, getDb = 112, getMaxMsgLenBytes = 108, getBufferLenBits = 984, getBufferLenTrits = 658, getDm0 = 38, getShaLvl = 1, getDr = 38, getC = 11, getMinCallsR = 11, getMinCallsMask = 14, getOID = [0,2,6], getPkLen = 112, getBitLvl = 112}
  "EES761EP2" -> ParamSet {getN = 761, getP = 3, getQ = 2048, getDf = 42, getDg = 253, getLLen = 1, getDb = 128, getMaxMsgLenBytes = 125, getBufferLenBits = 1136, getBufferLenTrits = 760, getDm0 = 42, getShaLvl = 1, getDr = 42, getC = 12, getMinCallsR = 13, getMinCallsMask = 16, getOID = [0,3,5], getPkLen = 128, getBitLvl = 128}
  "EES1087EP1" -> ParamSet {getN = 1087, getP = 3, getQ = 2048, getDf = 63, getDg = 362, getLLen = 1, getDb = 192, getMaxMsgLenBytes = 178, getBufferLenBits = 1624, getBufferLenTrits = 1086, getDm0 = 63, getShaLvl = 256, getDr = 63, getC = 13, getMinCallsR = 13, getMinCallsMask = 14, getOID = [0,5,5], getPkLen = 192, getBitLvl = 192}
  "EES1499EP1" -> ParamSet {getN = 1499, getP = 3, getQ = 2048, getDf = 79, getDg = 499, getLLen = 1, getDb = 256, getMaxMsgLenBytes = 247, getBufferLenBits = 2240, getBufferLenTrits = 1498, getDm0 = 79, getShaLvl = 256, getDr = 79, getC = 13, getMinCallsR = 17, getMinCallsMask = 19, getOID = [0,6,5], getPkLen = 256, getBitLvl = 256}
  _ -> error "Unsupported Parameter Set"
-- | Parameters describing one NTRU parameter set.  Field notes describe how
-- this module uses each value; fields marked "not read" are not referenced
-- by any function visible in this file.
data ParamSet = ParamSet
  { getN :: Int                -- ^ number of coefficient slots; arithmetic is folded at this length
  , getP :: Integer            -- ^ small modulus used for message trits
  , getQ :: Integer            -- ^ large modulus used for keys and ciphertexts
  , getDf :: Int               -- ^ count of ones and of minus-ones requested for the private polynomial
  , getDg :: Int               -- ^ count of ones requested for @g@ in key generation
  , getLLen :: Int             -- ^ number of octets holding the message length
  , getDb :: Int               -- ^ number of random bits prepended to the message (used as @getDb `div` 8@ octets)
  , getMaxMsgLenBytes :: Int   -- ^ longest message accepted by encryption, in octets
  , getBufferLenBits :: Int    -- ^ not read in this file
  , getBufferLenTrits :: Int   -- ^ not read in this file
  , getDm0 :: Int              -- ^ not read in this file
  , getShaLvl :: Int           -- ^ hash selector: 1 for SHA-1, 256 for SHA-256
  , getDr :: Int               -- ^ count of ones and of minus-ones placed in the blinding value
  , getC :: Int                -- ^ number of bits consumed per generated index
  , getMinCallsR :: Integer    -- ^ number of initial hash calls when filling generator buffers
  , getMinCallsMask :: Int     -- ^ not read in this file
  , getOID :: [Int]            -- ^ identifier octets placed at the start of the seed data
  , getPkLen :: Int            -- ^ number of public-key bits included in the seed data
  , getBitLvl :: Int           -- ^ not read in this file
  } deriving (Show)
-- | Hash function selected by 'getShaLvl': 256 picks 'sha256Octets' and 1
-- picks 'sha1Octets'.  Any other level raises an error.
getSHA :: ParamSet -> ([Integer] -> [Integer])
getSHA params
  | level == 256 = sha256Octets
  | level == 1 = sha1Octets
  | otherwise = error "Unsupported SHA function"
  where
    level = getShaLvl params
-- | Digest length in octets for the hash selected by 'getShaLvl': 32 for
-- level 256 and 20 for level 1.  Any other level raises an error.
getHLen :: ParamSet -> Int
getHLen params
  | level == 256 = 32
  | level == 1 = 20
  | otherwise = error "Unsupported SHA function"
  where
    level = getShaLvl params