module NTRU (keyGen112, keyGen128, keyGen192, keyGen256, encrypt112, encrypt128, encrypt192, encrypt256, decrypt112, decrypt128, decrypt192, decrypt256) where
import Data.Digest.Pure.SHA
import Data.List.Split
import Data.Sequence as Seq (index, update, empty, fromList, Seq)
import Data.Foldable as L (toList)
import Crypto.Random
import System.Random
import Math.Polynomial
import Math.NumberTheory.Moduli
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
-- | Coefficient list of a polynomial, constant term first.
fromPoly :: (Num a, Eq a, Integral a) => Poly a -> [a]
fromPoly p = polyCoeffs LE p

-- | Polynomial whose coefficients are given constant term first.
toPoly :: (Num a, Eq a, Integral a) => [a] -> Poly a
toPoly coeffs = poly LE coeffs
-- | Coefficient of @x^i@ in @p@.
--
-- Indices above the degree, and negative indices, yield 0 (the coefficient
-- of an absent term) instead of crashing inside '!!'.
polyCoef :: (Num a, Eq a, Integral a) => Poly a -> Int -> a
polyCoef p i
  | i < 0 = 0
  | otherwise = case drop i (fromPoly p) of
      c : _ -> c
      [] -> 0
-- | Ring operations on polynomials so that '+', '*', '-' and integer
-- literals can be used directly on 'Poly' values.
--
-- NOTE(review): this is an orphan instance (neither 'Num' nor 'Poly' is
-- defined in this module).
instance (Num a, Eq a) => Num (Poly a) where
  (+) = addPoly
  (*) = multPoly
  negate = negatePoly
  -- An integer literal denotes the constant polynomial; previously this was
  -- 'undefined', so any literal used at type Poly crashed.
  fromInteger k = poly LE [fromInteger k]
  abs = error "Num (Poly a): abs is not defined for polynomials"
  signum = error "Num (Poly a): signum is not defined for polynomials"
-- | Reduce @f@ modulo @x^n - 1@: the coefficient of @x^(k*n + j)@ is added
-- onto @x^j@.
--
-- Every block of @n@ coefficients is folded, so inputs of any degree end up
-- with degree below @n@ (previously only one block was folded, which left
-- high-order terms whenever @deg f >= 2n@). For @n <= 0@ the input is
-- returned unchanged.
reduceDegree :: (Num a, Eq a, Integral a) => Int -> Poly a -> Poly a
reduceDegree n f
  | n <= 0 = f
  | otherwise = foldr addPoly zero (map toPoly (chunksOf n (fromPoly f)))
-- | Reduce each coefficient with 'mod' @q@ (giving @[0, q)@ for positive @q@).
polyMod :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a
polyMod q = toPoly . map (`mod` q) . fromPoly
-- | Reduce each coefficient modulo @q@ into the centred range
-- @(q div 2 - q, q div 2]@, e.g. @{-1, 0, 1}@ for @q = 3@.
polyModInterval :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a
polyModInterval q f = toPoly [centre (c `mod` q) | c <- fromPoly f]
  where
    -- Residues above q/2 are shifted down by q so they become negative.
    centre r
      | r <= q `div` 2 = r
      | otherwise = r - q
-- | Reduce each coefficient modulo an 'Int' modulus @q@.
--
-- The reduction is done exactly in 'Integer'. The previous version converted
-- the coefficients to 'Int' first, which silently wraps coefficients outside
-- the 'Int' range before reducing them. Result coefficients lie in @[0, q)@
-- for positive @q@.
polyBigMod :: (Num a, Eq a, Integral a) => Int -> Poly a -> Poly a
polyBigMod q p = toPoly [fromInteger (toInteger c `mod` modulus) | c <- fromPoly p]
  where
    modulus = toInteger q
-- | The monomial @x^k@.
xPow :: (Num a, Eq a, Integral a) => Int -> Poly a
xPow k = x `powPoly` k
-- | Long division of @a@ by @b@ with coefficients reduced modulo @p@.
--
-- Returns @(quotient, remainder)@, both reduced mod @p@, with
-- @deg remainder < deg b@. The leading coefficient of @b@ must be invertible
-- mod @p@; otherwise 'inverseMod' raises an error.
divPolyMod :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a -> (Poly a, Poly a)
divPolyMod p a b = go zero a
  where
    n = polyDegree b
    -- Inverse of b's leading coefficient, used to cancel r's leading term.
    u = inverseMod (polyCoef b n) p
    go q r
      | d < n = (polyMod p q, polyMod p r)
      | otherwise =
          -- Term that cancels the leading coefficient of r.
          let v = scalePoly (u * polyCoef r d) (xPow (d - n))
           in go (polyMod p (q + v)) (polyMod p (r - v * b))
      where
        d = polyDegree r
-- | Extended Euclidean algorithm for polynomials with coefficients mod @p@.
--
-- Returns @(g, u)@, where @g@ is the last non-zero remainder (a gcd of @a@
-- and @b@) and @u@ satisfies @u * a ≡ g (mod b, mod p)@. Only the Bezout
-- coefficient of @a@ is tracked.
extendedEuclidean :: (Num a, Eq a, Integral a) => a -> Poly a -> Poly a -> (Poly a, Poly a)
extendedEuclidean p a b = go one a zero b
  where
    go u d v1 v3
      | polyIsZero v3 = (d, u)
      | otherwise =
          let (qt, t3) = divPolyMod p d v3
              t1 = polyMod p (u - qt * v1)
           in go v1 v3 t1 t3
-- | Sample a private polynomial @f = 1 + p*F@ until @f@ is invertible modulo
-- @(2, x^N - 1)@.
--
-- @F@ has @df@ coefficients equal to +1 and @df@ equal to -1. Returns
-- @(f, f^-1 mod 2)@. It retries (by recursion) until an invertible @f@ is
-- found.
findInversable :: (Num a, Eq a, Integral a) => [Int] -> IO (Poly a, Poly a)
findInversable params = do
  fPrime <- genRandPoly n df df
  let f = scalePoly (getP params) fPrime + one
      (g, fInv2) = extendedEuclidean 2 f (xPow n - one)
  if g == one then return (f, fInv2) else findInversable params
  where
    n = getN params
    df = getDf params
-- | Lift @b@, an inverse of @a@ modulo @(2, x^deg - 1)@, to an inverse
-- modulo @(q, x^deg - 1)@.
--
-- Each Newton step @b <- 2b - a*b^2@ doubles the 2-adic precision
-- (mod 4, 16, 256, 65536, ...). Steps continue until @2^precision >= q@,
-- and the result is reduced mod @q@. Previously @q@ was ignored and 2^11 was
-- hard-coded. @q@ must be a power of two for the result to be an inverse
-- mod @q@.
inverseLift :: (Num a, Eq a, Integral a) => Poly a -> Poly a -> Int -> a -> Poly a
inverseLift a b0 deg q = go b0 (1 :: Int)
  where
    go b prec
      | (2 :: Integer) ^ prec >= toInteger q = polyMod q b
      | otherwise =
          let prec' = 2 * prec
              abb = reduceDegree deg $! a * (reduceDegree deg $! b * b)
              b' = polyBigMod (2 ^ prec') (scalePoly 2 b - abb)
           in go b' prec'
-- | Generate a key pair for the given parameter vector.
--
-- Returns @(h, f)@ as coefficient lists, where
-- @h = p * f^-1 * g mod (q, x^N - 1)@ is the public key and @f@ is the
-- private polynomial. @g@ has @dg@ coefficients equal to +1 and @dg - 1@
-- equal to -1.
generateKeyPair :: (Num a, Eq a, Integral a) => [Int] -> IO ([a], [a])
generateKeyPair params = do
  let n = getN params
      dg = getDg params
      q = getQ params
  (f, fInv2) <- findInversable params
  let fq = inverseLift f fInv2 n q
  g <- genRandPoly n dg (dg - 1)
  let h = polyMod q $! reduceDegree n $! scalePoly (getP params) $! fq * g
  return (fromPoly h, fromPoly f)
-- | Build the seed @OID || msg || b || hTrunc@ used to derive the blinding
-- polynomial.
--
-- @hTrunc@ is the public key serialised at 11 bits per coefficient
-- ('bigIntToBits'), truncated to the first @pkLen@ bits rounded down to a
-- whole number of octets, and packed into octets.
genSData :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [Int] -> [a]
genSData h msg b params = map fromIntegral (getOID params) ++ msg ++ b ++ hTrunc
  where
    pkLen = getPkLen params
    hBits = take (pkLen - pkLen `mod` 8) (concatMap bigIntToBits h)
    hTrunc = map (fromIntegral . bitsToInt) (chunksOf 8 hBits)
-- | Derive the blinding polynomial from @seed@.
--
-- Returns a length-N coefficient list with @dr@ entries equal to +1 and
-- @dr@ equal to -1 (@dr = getDr params@); positions are drawn with 'igf'.
bpgm :: (Num a, Eq a, Integral a) => [a] -> [Int] -> [a]
bpgm seed params = L.toList withMinusOnes
  where
    t = getDr params
    (i0, st) = igf ([], [], 0) seed params
    -- The first +1 comes from the initialising igf call.
    start = Seq.update i0 1 (Seq.fromList (replicate (getN params) 0))
    withPlusOnes = rlooper st 1 start (t - 1) params
    -- NOTE(review): this restarts from the same IGF state 'st'; the indices
    -- reused from the first pass are all occupied and are skipped.
    withMinusOnes = rlooper st (-1) withPlusOnes t params
-- | Write @val@ into @t@ currently-zero slots of @r@, choosing each slot with
-- 'igf'. Indices that are already non-zero are skipped and redrawn.
-- A non-positive @t@ returns @r@ unchanged.
rlooper :: (Num a, Eq a, Integral a) => ([a], [a], a) -> a -> Seq.Seq a -> Int -> [Int] -> Seq.Seq a
rlooper s val r t params
  | t <= 0 = r
  | Seq.index r i == 0 = rlooper s' val (Seq.update i val r) (t - 1) params
  | otherwise = rlooper s' val r t params
  where
    (i, s') = igf s [] params
-- | Index generation step: produce an index in @[0, N)@ and the updated
-- generator state @(z, buffer, counter)@.
--
-- A non-empty @seed@ (re)initialises the state; an empty one continues from
-- @state@.
igf :: (Num a, Eq a, Integral a) => ([a], [a], a) -> [a] -> [Int] -> (Int, ([a], [a], a))
igf state seed params = (idx `mod` getN params, (z, buf', counter'))
  where
    (z, buf, counter) = extractVariables state seed params
    (idx, buf', counter') = genIndex counter buf z params
-- | Keep the existing IGF state when no seed is given; otherwise build a
-- fresh state from the seed.
extractVariables :: (Num a, Eq a, Integral a) => ([a], [a], a) -> [a] -> [Int] -> ([a], [a], a)
extractVariables state seed params
  | null seed = state
  | otherwise = igfinit seed params
-- | Initial IGF state: @z = hash(seed)@, a bit buffer filled by
-- @minCallsR@ hash calls (counters @0 .. minCallsR - 1@), and the next
-- counter value.
igfinit :: (Num a, Eq a, Integral a) => [a] -> [Int] -> ([a], [a], a)
igfinit seed params =
  let hashWith = getSHA params
      calls = getMinCallsR params
      seedHash = hashWith seed
   in (seedHash, buildM 0 calls seedHash hashWith [], calls)
-- | Draw one @c@-bit candidate index from the IGF bit buffer.
--
-- When the buffer holds fewer than @c@ bits, it is refilled by hashing
-- @z || I2OSP(counter)@. Candidates @>= 2^c - (2^c mod N)@ are rejected and
-- redrawn so that the caller's final @mod N@ is unbiased.
-- Returns @(index, remaining buffer, next counter)@.
genIndex :: (Num a, Eq a, Integral a) => a -> [a] -> [a] -> [Int] -> (Int, [a], a)
genIndex counter buf z params
  | i >= limit = genIndex counter' buf' z params
  | otherwise = (i, buf', counter')
  where
    c = getC params
    n = getN params
    remLen = length buf
    -- buildM appends 8 bits per digest octet, while getHLen counts octets.
    hLenBits = 8 * getHLen params
    -- Hash calls needed to obtain the missing c - remLen bits (ceiling).
    calls = (c - remLen + hLenBits - 1) `div` hLenBits
    -- buildM returns buf with the new bits appended, so its result must not
    -- be prefixed with buf again (that used to duplicate the buffer).
    (bits, counter')
      | remLen >= c = (buf, counter)
      | otherwise =
          let threshold = counter + fromIntegral calls
           in (buildM counter threshold z (getSHA params) buf, threshold)
    (b, buf') = splitAt c bits
    i = fromIntegral (bitsToInt b) :: Int
    limit = 2 ^ c - (2 ^ c `mod` n)
-- | Append to @buf@ the bits of @hash(z || I2OSP(k))@ for every counter @k@
-- from @count@ up to @cThreshold - 1@, in increasing order.
buildM :: (Num a, Eq a, Integral a) => a -> a -> [a] -> ([a]->[a]) -> [a] -> [a]
buildM count cThreshold z shaFn buf =
  buf ++ concatMap hashBits [count .. cThreshold - 1]
  where
    hashBits k = intsToBits (shaFn (z ++ i2osp k 3))
-- | Big-endian encoding of @i@ as @n + 1@ octets.
--
-- For @i < 256@ this is @n@ zero octets followed by @i@, as before. Larger
-- values are now split into proper octets instead of emitting a single
-- out-of-range element.
i2osp :: (Num a, Eq a, Integral a) => a -> a -> [a]
i2osp i n = [(i `div` (256 ^ k)) `mod` 256 | k <- [n, n - 1 .. 0]]
-- | Convert a lazy ByteString to a strict one by joining its chunks.
bToStrict :: BL.ByteString -> B.ByteString
bToStrict lazyBytes = B.concat (BL.toChunks lazyBytes)
-- | SHA-1 of a list of octet values, returned as its 20 digest octets.
sha1Octets :: (Num a, Eq a, Integral a) => [a] -> [a]
sha1Octets = map fromIntegral . BL.unpack . bytestringDigest . sha1 . BL.pack . map fromIntegral

-- | SHA-256 of a list of octet values, returned as its 32 digest octets.
sha256Octets :: (Num a, Eq a, Integral a) => [a] -> [a]
sha256Octets = map fromIntegral . BL.unpack . bytestringDigest . sha256 . BL.pack . map fromIntegral
-- | Mask generation.
--
-- Hashes @seed@, expands the digest into a buffer of hash outputs, and
-- converts the buffer octets into base-3 digits. Returns the first N trits,
-- each in @{0, 1, 2}@.
mgf :: (Num a, Eq a, Integral a) => [a] -> [Int] -> [a]
mgf seed params = take n (finishI firstTrits n minCalls z shaFn)
  where
    n = getN params
    shaFn = getSHA params
    minCalls = getMinCallsR params
    z = shaFn seed
    firstTrits = formatI (buildBuffer 0 minCalls z shaFn [])
-- | Append to @buffer@ the digest octets of @hash(z || I2OSP(k))@ for each
-- counter @k@ from @counter@ up to @minCallsR - 1@, in increasing order.
buildBuffer :: (Num a, Eq a, Integral a) => a -> a -> [a] -> ([a]->[a]) -> [a] -> [a]
buildBuffer counter minCallsR z shaFn buffer =
  buffer ++ concatMap digestFor [counter .. minCallsR - 1]
  where
    digestFor k = shaFn (z ++ i2osp k 3)
-- | The @n@ least-significant base-3 digits of @o@, least significant first.
toTrits :: (Num a, Eq a, Integral a) => a -> a -> [a]
toTrits n o
  | n <= 0 = []
  | otherwise = o `mod` 3 : toTrits (n - 1) (o `div` 3)
-- | Extend the trit list @i@ with the trits of one more hash output per round
-- (counter @counter@, @counter + 1@, ...) until it holds at least @n@ trits.
--
-- New trits are appended. Previously each round replaced @i@ with a single
-- block's trits, which never reached @n@ and looped forever.
finishI :: (Num a, Eq a, Integral a) => [a] -> Int -> a -> [a] -> ([a] -> [a]) -> [a]
finishI i n counter z shaFn
  | length i >= n = i
  | otherwise = finishI (i ++ formatI block) n (counter + 1) z shaFn
  where
    block = buildBuffer counter (counter + 1) z shaFn []
-- | Turn every octet below 243 (= 3^5) into its five base-3 digits.
-- Octets of 243 or more are dropped.
formatI :: (Num a, Eq a, Integral a) => [a] -> [a]
formatI buf = [trit | octet <- buf, octet < 243, trit <- toTrits 5 octet]
-- | Encrypt the octets @msg@ under the public key coefficients @h@.
--
-- The message block is @b || len || msg || zero padding@, where @b@ is
-- @db/8@ random octets. It is converted to trits and masked with
-- @mgf(R mod 4)@, where @R = r*h mod q@ and @r@ is the blinding polynomial
-- derived from 'genSData'. The result is @e = R + m'@ mod q, returned as a
-- coefficient list.
-- Errors with "message too long" when @msg@ exceeds the maximum length.
encrypt :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a] -> IO [a]
encrypt params msg h
  | l > maxLength = error "message too long"
  | otherwise = do
      b <- randByteString (getDb params `div` 8)
      let padding = replicate (maxLength - l) 0
          m = b ++ [fromIntegral l] ++ msg ++ padding
          mBin = addPadding (intsToBits m)
          -- 3 bits -> 2 trits in {-1, 0, 1}
          mTrin = concatMap binToTern (chunksOf 3 mBin)
          r = bpgm (genSData h msg b params) params
          bigR = polyMod q $ reduceDegree n $ toPoly r * toPoly h
          or4 = toOctets $ fromPoly $ polyMod 4 bigR
          mask = mgf or4 params
          m' = polyModInterval p $ toPoly mask + toPoly mTrin
      return $ fromPoly $ polyMod q $ bigR + m'
  where
    l = length msg
    maxLength = getMaxMsgLenBytes params
    n = getN params
    q = getQ params
    p = getP params
-- | Decrypt the ciphertext @e@ with private key @f@ and public key @h@.
--
-- Recovers the masked trits @f*e mod q@ (centred) reduced mod p, unmasks
-- them, and converts them back to octets @b || len || msg || padding@.
-- It then re-derives the blinding polynomial to check @R@.
-- Returns @msg@; errors through 'checkValid' when the padding is non-zero
-- or the re-encrypted @R@ does not match.
decrypt :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a] -> [a] -> [a]
decrypt params f h e = checkValid cm validR validRemainder
  where
    n = getN params
    p = getP params
    q = getQ params
    bLen = getDb params `div` 8
    ci = polyMod p $ polyModInterval q $ reduceDegree n $ toPoly f * toPoly e
    cR = polyMod q $ toPoly e - polyModInterval p ci
    coR4 = toOctets $ fromPoly $ polyMod 4 cR
    cMask = polyMod p $ toPoly $ mgf coR4 params
    cMTrin = improperPolynomial n $ fromPoly $ polyModInterval p $ ci - cMask
    -- 2 trits -> 3 bits, then 8 bits -> 1 octet; partial groups are dropped.
    cMBin = concatMap ternToBin $ chunksOf 2 $ takeMultipleOf 2 cMTrin
    cM = map bitsToInt $ chunksOf 8 $ takeMultipleOf 8 cMBin
    (cb, rest) = splitAt bLen cM
    ([cl], rest') = splitAt (getLLen params) rest
    (cm, rest'') = splitAt (fromIntegral cl) rest'
    cr = bpgm (genSData h cm cb params) params
    cR' = polyMod q $ reduceDegree n $ toPoly cr * toPoly h
    validR = cR' == cR
    validRemainder = all (== 0) rest''
    takeMultipleOf k xs = take (length xs - length xs `mod` k) xs
-- | Return the recovered message when both decryption checks passed.
--
-- The remainder check is evaluated first, so a padding failure is reported
-- even if the blinding-polynomial check also failed.
checkValid :: (Num a, Eq a, Integral a) => [a] -> Bool -> Bool -> [a]
checkValid m validR validRemainder
  | not validRemainder = error "Failure Checking Remainder of Message"
  | not validR = error "Failure Verifying Blinding Polynomial"
  | otherwise = m
-- | Multiplicative inverse of @x@ modulo @y@, computed with 'invertMod'.
-- Errors when no inverse exists.
inverseMod :: (Num a, Eq a, Integral a) => a -> a -> a
inverseMod x y =
  maybe (error "Could not calculate inverseMod") fromIntegral $
    invertMod (fromIntegral x) (fromIntegral y)
-- | @size@ random octets from a freshly seeded 'SystemRandom' generator.
-- A generator failure is turned into an 'error' carrying its 'show'.
randByteString :: (Num a, Eq a, Integral a) => Int -> IO [a]
randByteString size = do
  gen <- newGenIO :: IO SystemRandom
  either (error . show) (return . unpackByteString . fst) (genBytes size gen)
-- | Octets of a strict ByteString as integers.
unpackByteString :: (Num a, Eq a, Integral a) => BC.ByteString -> [a]
unpackByteString = map fromIntegral . B.unpack
-- | Map 3 bits to 2 trits in @{-1, 0, 1}@.
--
-- The table is a bijection onto the 8 trit pairs other than @[-1, -1]@, so
-- 'ternToBin' can invert it. Any other input is an error.
binToTern :: (Num a, Eq a, Integral a) => [a] -> [a]
binToTern [0, 0, 0] = [0, 0]
binToTern [0, 0, 1] = [0, 1]
binToTern [0, 1, 0] = [0, -1]
binToTern [0, 1, 1] = [1, 0]
binToTern [1, 0, 0] = [1, 1]
binToTern [1, 0, 1] = [1, -1]
binToTern [1, 1, 0] = [-1, 0]
binToTern [1, 1, 1] = [-1, 1]
binToTern _ = error "Problem converting binary to trinary"
-- | Inverse of 'binToTern': map 2 trits in @{-1, 0, 1}@ back to 3 bits.
-- The pair @[-1, -1]@ (and anything else) is rejected with an error.
ternToBin :: (Num a, Eq a, Integral a) => [a] -> [a]
ternToBin [0, 0] = [0, 0, 0]
ternToBin [0, 1] = [0, 0, 1]
ternToBin [0, -1] = [0, 1, 0]
ternToBin [1, 0] = [0, 1, 1]
ternToBin [1, 1] = [1, 0, 0]
ternToBin [1, -1] = [1, 0, 1]
ternToBin [-1, 0] = [1, 1, 0]
ternToBin [-1, 1] = [1, 1, 1]
ternToBin _ = error " Problem converting trinary to binary"
-- | Append zero bits until the length is a multiple of 3.
addPadding :: (Num a, Eq a, Integral a) => [a] -> [a]
addPadding bits = bits ++ replicate missing 0
  where
    missing = (3 - length bits `mod` 3) `mod` 3
-- | The @n + 1@ low bits of @b@, most significant first. Any bits of @b@
-- above position @n@ are folded into the first element.
unpackByte :: (Num a, Eq a, Integral a) => a -> a -> [a]
unpackByte n b
  | n < 0 = []
  | otherwise = (b `div` (2 ^ n)) : unpackByte (n - 1) (b `mod` (2 ^ n))
-- | 8-bit big-endian expansion of an octet value.
intToBits :: (Num a, Eq a, Integral a) => a -> [a]
intToBits octet = unpackByte 7 octet

-- | 11-bit big-endian expansion (used for public-key coefficients).
bigIntToBits :: (Num a, Eq a, Integral a) => a -> [a]
bigIntToBits coeff = unpackByte 10 coeff

-- | Concatenated 8-bit expansions of a list of octets.
intsToBits :: (Num a, Eq a, Integral a) => [a] -> [a]
intsToBits octets = octets >>= intToBits
-- | Interpret a list of bits, most significant first, as an integer.
-- The empty list yields 0.
bitsToInt :: (Num a, Eq a, Integral a) => [a] -> a
bitsToInt bits = sum (zipWith (*) (reverse bits) (iterate (* 2) 1))
-- | Random ternary polynomial with @n@ coefficients, of which exactly @pos@
-- are +1 and @neg@ are -1 (when @pos + neg <= n@); the rest are 0.
--
-- Each slot draws uniformly from the remaining slots, so the remaining
-- counts are always honoured. The old loop swapped the +1/-1 counters on
-- every step and decremented the wrong one.
--
-- NOTE(review): 'randomRIO' uses System.Random's global generator, which is
-- not a cryptographic RNG; key material should likely come from
-- Crypto.Random as 'randByteString' does. Confirm before production use.
genRandPoly :: (Num a, Eq a, Integral a) => Int -> Int -> Int -> IO (Poly a)
genRandPoly n pos neg = toPoly <$> go [] n pos neg
  where
    go acc slots plus minus
      | slots <= 0 = return acc
      | otherwise = do
          r <- randomRIO (0, slots - 1)
          let (coef, plus', minus')
                | r < plus = (1, plus - 1, minus)
                | r < plus + minus = (-1, plus, minus - 1)
                | otherwise = (0, plus, minus)
          go (coef : acc) (slots - 1) plus' minus'
-- | Right-pad a coefficient list with zeros to length @n@. Lists that are
-- already at least @n@ long are returned unchanged.
improperPolynomial :: (Num a, Eq a, Integral a) => Int -> [a] -> [a]
improperPolynomial n coeffs = coeffs ++ replicate (n - length coeffs) 0
-- | Append zeros until the length is a multiple of 8.
padInt8 :: (Num a, Eq a, Integral a) => [a] -> [a]
padInt8 lst = lst ++ replicate ((8 - length lst `mod` 8) `mod` 8) 0
-- | Pack the two low bits of each coefficient, in order, into octets.
-- The final octet is zero-padded on the right.
toOctets :: (Num a, Eq a, Integral a) => [a] -> [a]
toOctets coeffs = map (bitsToInt . padInt8) (chunksOf 8 (concatMap lowTwoBits coeffs))
  where
    -- unpackByte 7 yields exactly 8 bits; the last two are the low bits.
    lowTwoBits = drop 6 . unpackByte 7
-- | Flat parameter vector for a security level (112, 128, 192 or 256 bits).
-- The meaning of each position is given by the get* accessors.
genParams :: (Num a, Eq a, Integral a) => a -> [Int]
genParams bitLevel = case toInteger bitLevel of
  112 -> [401,3,2048,113,133,1,112,60,600,400,113,1,113,11,32,0,0,2,4,114,112]
  128 -> [449,3,2048,134,149,1,128,67,672,448,134,1,134,9,31,9,0,3,3,128,128]
  192 -> [677,3,2048,157,225,1,192,101,1008,676,157,256,157,11,27,9,0,5,3,192,192]
  256 -> [1087,3,2048,120,362,1,256,170,1624,1086,120,256,120,13,25,14,0,6,3,256,256]
  _ -> error "BitLevel must be 112, 128, 192, 256"
-- Accessors for the flat parameter vector built by 'genParams'. Each one
-- reads a fixed position of that list, so they must stay in sync with the
-- order of the literals there. Accessors that feed polynomial or counter
-- arithmetic convert with 'fromIntegral' so they can be used at any numeric
-- type.

getN :: [Int] -> Int
getN = head

getP :: Num b => [Int] -> b
getP lst = fromIntegral $ lst !! 1

getQ :: Num b => [Int] -> b
getQ lst = fromIntegral $ lst !! 2

getDf :: [Int] -> Int
getDf lst = lst !! 3

getDg :: [Int] -> Int
getDg lst = lst !! 4

getLLen :: [Int] -> Int
getLLen lst = lst !! 5

getDb :: [Int] -> Int
getDb lst = lst !! 6

getMaxMsgLenBytes :: [Int] -> Int
getMaxMsgLenBytes lst = lst !! 7

getBufferLenBits :: [Int] -> Int
getBufferLenBits lst = lst !! 8

getBufferLenTrits :: [Int] -> Int
getBufferLenTrits lst = lst !! 9

getDm0 :: [Int] -> Int
getDm0 lst = lst !! 10

-- | Hash selected by position 11: 1 selects SHA-1, any other value SHA-256.
getSHA :: (Num a, Eq a, Integral a) => [Int] -> [a] -> [a]
getSHA lst
  | lst !! 11 == 1 = sha1Octets
  | otherwise = sha256Octets

-- | Digest length in octets of the hash chosen by 'getSHA'.
getHLen :: Num b => [Int] -> b
getHLen lst
  | lst !! 11 == 1 = 20
  | otherwise = 32

getDr :: [Int] -> Int
getDr lst = lst !! 12

getC :: [Int] -> Int
getC lst = lst !! 13

getMinCallsR :: Num b => [Int] -> b
getMinCallsR lst = fromIntegral $ lst !! 14

getMinCallsMask :: Num b => [Int] -> b
getMinCallsMask lst = fromIntegral $ lst !! 15

getOID :: [Int] -> [Int]
getOID lst = [lst !! 16, lst !! 17, lst !! 18]

getPkLen :: [Int] -> Int
getPkLen lst = lst !! 19

getLvl :: [Int] -> Int
getLvl lst = lst !! 20
-- | Generate a key pair for the 112-bit parameter set.
-- Result is @(public key h, private key f)@ as coefficient lists.
keyGen112 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen112 = generateKeyPair $ genParams (112 :: Int)

-- | Generate a key pair for the 128-bit parameter set.
keyGen128 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen128 = generateKeyPair $ genParams (128 :: Int)

-- | Generate a key pair for the 192-bit parameter set.
keyGen192 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen192 = generateKeyPair $ genParams (192 :: Int)

-- | Generate a key pair for the 256-bit parameter set.
keyGen256 :: (Num a, Eq a, Integral a) => IO ([a], [a])
keyGen256 = generateKeyPair $ genParams (256 :: Int)
-- | Encrypt message octets under a public key (112-bit parameter set).
-- Arguments: message octets, then public-key coefficients.
encrypt112 :: (Num a, Eq a, Integral a) => [a] -> [a] -> IO [a]
encrypt112 = encrypt $ genParams (112 :: Int)

-- | Encrypt message octets under a public key (128-bit parameter set).
encrypt128 :: (Num a, Eq a, Integral a) => [a] -> [a] -> IO [a]
encrypt128 = encrypt $ genParams (128 :: Int)

-- | Encrypt message octets under a public key (192-bit parameter set).
encrypt192 :: (Num a, Eq a, Integral a) => [a] -> [a] -> IO [a]
encrypt192 = encrypt $ genParams (192 :: Int)

-- | Encrypt message octets under a public key (256-bit parameter set).
encrypt256 :: (Num a, Eq a, Integral a) => [a] -> [a] -> IO [a]
encrypt256 = encrypt $ genParams (256 :: Int)
-- | Decrypt a ciphertext (112-bit parameter set).
-- Arguments: private key f, public key h, ciphertext e.
decrypt112 :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [a]
decrypt112 = decrypt $ genParams (112 :: Int)

-- | Decrypt a ciphertext (128-bit parameter set).
decrypt128 :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [a]
decrypt128 = decrypt $ genParams (128 :: Int)

-- | Decrypt a ciphertext (192-bit parameter set).
decrypt192 :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [a]
decrypt192 = decrypt $ genParams (192 :: Int)

-- | Decrypt a ciphertext (256-bit parameter set).
decrypt256 :: (Num a, Eq a, Integral a) => [a] -> [a] -> [a] -> [a]
decrypt256 = decrypt $ genParams (256 :: Int)