{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiWayIf #-}
module Codec.Crypto.RSA.Pure(
RSAError(..)
, HashInfo(..)
, PrivateKey(..)
, PublicKey(..)
, generateKeyPair
, encrypt
, encryptOAEP
, encryptPKCS
, decrypt
, decryptOAEP
, decryptPKCS
, sign
, verify
, MGF
, generateMGF1
, rsaes_oaep_encrypt
, rsaes_oaep_decrypt
, rsaes_pkcs1_v1_5_encrypt
, rsaes_pkcs1_v1_5_decrypt
, rsassa_pkcs1_v1_5_sign
, rsassa_pkcs1_v1_5_verify
, hashSHA1
, hashSHA224, hashSHA256, hashSHA384, hashSHA512
, largeRandomPrime
, generatePQ
, chunkify
, os2ip, i2osp
, rsa_dp, rsa_ep
, rsa_vp1, rsa_sp1
, modular_inverse
, modular_exponentiation
, randomBS, randomNZBS
)
where
import Control.Exception
import Control.Monad
import Crypto.Random
import Crypto.Types.PubKey.RSA
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits
import Data.ByteString.Lazy(ByteString)
import qualified Data.ByteString.Lazy as BS
import Data.Digest.Pure.SHA
import Data.Int
import Data.Typeable
-- | Every failure the RSA routines in this module can report.
data RSAError
  = RSAError String           -- ^ Generic failure carrying a free-form message.
  | RSAKeySizeTooSmall        -- ^ Key (or requested key size) too small for the operation.
  | RSAIntegerTooLargeToPack  -- ^ 'i2osp' could not fit the integer into the requested length.
  | RSAMessageRepOutOfRange   -- ^ Message representative outside @[0, n)@.
  | RSACipherRepOutOfRange    -- ^ Ciphertext / signature representative outside @[0, n)@.
  | RSAMessageTooShort        -- ^ Intended EMSA-PKCS1-v1_5 encoded length too short.
  | RSAMessageTooLong         -- ^ Message does not fit in one OAEP block.
  | RSAMaskTooLong            -- ^ MGF1 mask length limit exceeded.
  | RSAIncorrectSigSize       -- ^ Signature length differs from the key size.
  | RSAIncorrectMsgSize       -- ^ Wrong message / ciphertext length for PKCS#1 v1.5.
  | RSADecryptionError        -- ^ Decryption failed (bad padding, label or size).
  | RSAGenError GenError      -- ^ The random generator reported an error.
  deriving (Eq, Show, Typeable)

-- | Allows 'RSAError' to be thrown / caught as an exception.
instance Exception RSAError
-- | A hash function together with the byte prefix that
-- @emsa_pkcs1_v1_5_encode@ places in front of the digest (the DER-encoded
-- header identifying the hash algorithm).
data HashInfo = HashInfo
  { algorithmIdent :: ByteString                -- ^ Algorithm identifier prefix.
  , hashFunction   :: ByteString -> ByteString  -- ^ The hash itself.
  }
-- | Renders any 'SystemRandom' as the constant string @"SystemRandom"@.
--
-- NOTE(review): this is an orphan instance (neither 'Show' nor
-- 'SystemRandom' is defined in this module); kept for compatibility.
instance Show SystemRandom where
  showsPrec _ _ = showString "SystemRandom"
-- | Keys that can report the byte length of their modulus.
class RSAKey a where
  -- | Modulus size in bytes (the block size used for chunking).
  genKeySize :: a -> Int

instance RSAKey PublicKey where
  genKeySize pub = public_size pub

instance RSAKey PrivateKey where
  genKeySize priv = private_size priv
-- | Serialises a public key as an 8-byte big-endian size field followed by
-- the modulus, left-padded with zeros to that many bytes.
--
-- NOTE(review): the public exponent is not written. 'get' always rebuilds
-- the key with @e = 65537@, so keys with any other exponent do not
-- round-trip.
instance Binary PublicKey where
  put pk = do
    let size = public_size pk
    sizeBytes    <- failOnError (i2osp size 8)
    modulusBytes <- failOnError (i2osp (public_n pk) size)
    putLazyByteString sizeBytes
    putLazyByteString modulusBytes
  get = do
    sizeField <- getLazyByteString 8
    let len = fromIntegral (os2ip sizeField) :: Int64
    modulus <- fmap os2ip (getLazyByteString len)
    return (PublicKey (fromIntegral len) modulus 65537)
-- | Serialises a private key as its public key (using the 'PublicKey'
-- instance) followed by the private exponent @d@, padded to the modulus size.
--
-- NOTE(review): only @d@ is stored. 'get' fills the remaining five
-- 'PrivateKey' fields (primes and, presumably, CRT values) with 0.
instance Binary PrivateKey where
  put pk = do
    let pub = private_pub pk
    put pub
    dBytes <- failOnError (i2osp (private_d pk) (public_size pub))
    putLazyByteString dBytes
  get = do
    pub <- get
    let dLen = fromIntegral (public_size pub) :: Int64
    d <- fmap os2ip (getLazyByteString dLen)
    return (PrivateKey pub d 0 0 0 0 0)
-- | Unwrap a 'Right' inside any monad, or abort with the 'show'n error.
--
-- Uses 'error' rather than 'fail'. Since base-4.13, 'fail' belongs to
-- 'MonadFail', which a bare 'Monad' constraint does not provide, so the
-- 'fail' version no longer compiles. The 'Put' monad this is used from
-- presumably has no 'MonadFail' instance either.
failOnError :: (Monad m, Show a) => Either a b -> m b
failOnError (Left e)  = error (show e)
failOnError (Right b) = return b
-- | Generate an RSA key pair with public exponent 65537.
--
-- The recorded key size is @sizeBits `div` 8@ bytes; any remainder bits
-- are dropped. The private key carries @d@, @p@ and @q@. Its last three
-- fields (presumably the CRT parameters) are left as 0.
generateKeyPair :: CryptoRandomGen g =>
                   g -> Int ->
                   Either RSAError (PublicKey, PrivateKey, g)
generateKeyPair g sizeBits =
  generatePQ g byteLen >>= \(p, q, g') ->
    let modulus = p * q
        totient = (p - 1) * (q - 1)
        pubExp  = 65537
        pub     = PublicKey byteLen modulus pubExp
        priv    = PrivateKey pub (modular_inverse pubExp totient) p q 0 0 0
    in Right (pub, priv, g')
  where
    byteLen = fromIntegral (sizeBits `div` 8)
-- | Sign a message with RSASSA-PKCS1-v1_5 using SHA-256.
sign :: PrivateKey -> ByteString -> Either RSAError ByteString
sign privKey msg = rsassa_pkcs1_v1_5_sign hashSHA256 privKey msg

-- | Check an RSASSA-PKCS1-v1_5 / SHA-256 signature. Arguments are the
-- signer's public key, the message, and the signature, in that order.
verify :: PublicKey ->
          ByteString ->
          ByteString ->
          Either RSAError Bool
verify pubKey msg sig = rsassa_pkcs1_v1_5_verify hashSHA256 pubKey msg sig
-- | Encrypt with RSAES-OAEP, using SHA-256 both as the OAEP hash and inside
-- MGF1, with an empty label. Long messages are split into several blocks.
encrypt :: CryptoRandomGen g =>
           g -> PublicKey -> ByteString ->
           Either RSAError (ByteString, g)
encrypt gen key msg =
  let digest = bytestringDigest . sha256
  in encryptOAEP gen digest (generateMGF1 digest) BS.empty key msg
-- | RSAES-OAEP encryption of an arbitrary-length message.
--
-- The plaintext is cut into chunks of at most @k - 2*hLen - 2@ bytes. Each
-- chunk is encrypted independently, threading the generator through, and
-- the resulting key-sized blocks are concatenated. Fails with
-- 'RSAKeySizeTooSmall' when the key leaves no room for any payload.
encryptOAEP :: CryptoRandomGen g
            => g
            -> (ByteString -> ByteString)  -- hash function
            -> MGF                         -- mask generation function
            -> ByteString                  -- label
            -> PublicKey
            -> ByteString                  -- message
            -> Either RSAError (ByteString, g)
encryptOAEP g hash mgf l k m
  | room <= 0 = Left RSAKeySizeTooSmall
  | otherwise =
      fmap (\(blocks, g') -> (BS.concat blocks, g'))
           (mapM' g (chunkBSForOAEP k hash m) encryptChunk)
  where
    hLen = fromIntegral (BS.length (hash BS.empty))
    room = public_size k - 2 * hLen - 2
    encryptChunk gen chunk = rsaes_oaep_encrypt gen hash mgf k l chunk
-- | RSAES-PKCS1-v1_5 encryption of an arbitrary-length message.
--
-- The plaintext is split into chunks of at most @k - 11@ bytes, each chunk
-- is encrypted independently, and the ciphertext blocks are concatenated.
--
-- Fails with 'RSAKeySizeTooSmall' when the key is 11 bytes or smaller,
-- matching the guard in 'encryptOAEP'. Without this check, a chunk size of 0
-- makes 'chunkify' return an endless list of empty chunks. For an 11-byte
-- key each of those encrypts successfully, so the chunk loop would never
-- finish.
encryptPKCS :: CryptoRandomGen g =>
               g -> PublicKey -> ByteString ->
               Either RSAError (ByteString, g)
encryptPKCS g k m =
  do unless (public_size k > 11) $ Left RSAKeySizeTooSmall
     let chunks = chunkBSForPKCS k m
     (chunks', g') <- mapM' g chunks (\ x -> rsaes_pkcs1_v1_5_encrypt x k)
     return (BS.concat chunks', g')
-- | Map a generator-consuming step over a list of chunks, left to right.
--
-- Each call receives the generator returned by the previous one, and the
-- final generator is returned with the results. Stops at the first 'Left'.
mapM' :: CryptoRandomGen g =>
         g -> [ByteString] ->
         (g -> ByteString -> Either RSAError (ByteString, g)) ->
         Either RSAError ([ByteString], g)
mapM' g0 items step = go g0 items
  where
    go gen []       = Right ([], gen)
    go gen (b : bs) =
      case step gen b of
        Left err         -> Left err
        Right (b', gen') ->
          case go gen' bs of
            Left err           -> Left err
            Right (bs', gen'') -> Right (b' : bs', gen'')
-- | Decrypt output of 'encrypt': RSAES-OAEP with SHA-256 as the hash and
-- inside MGF1, and an empty label.
decrypt :: PrivateKey -> ByteString -> Either RSAError ByteString
decrypt key = decryptOAEP sha256Hash (generateMGF1 sha256Hash) BS.empty key
  where
    sha256Hash bs = bytestringDigest (sha256 bs)
-- | Inverse of 'encryptOAEP'. Splits the ciphertext into key-sized blocks,
-- RSAES-OAEP-decrypts each one with the given hash, MGF and label, and
-- concatenates the plaintexts. Fails on the first bad block.
decryptOAEP :: (ByteString -> ByteString)  -- hash function
            -> MGF                         -- mask generation function
            -> ByteString                  -- label
            -> PrivateKey
            -> ByteString                  -- ciphertext
            -> Either RSAError ByteString
decryptOAEP hash mgf l k m =
  fmap BS.concat (mapM (rsaes_oaep_decrypt hash mgf k l) blocks)
  where
    blocks = chunkify m (fromIntegral (private_size k))
-- | Inverse of 'encryptPKCS'. Decrypts each key-sized block with
-- RSAES-PKCS1-v1_5 and concatenates the plaintexts, stopping at the first
-- failure.
decryptPKCS :: PrivateKey -> ByteString -> Either RSAError ByteString
decryptPKCS k m = BS.concat `fmap` mapM (rsaes_pkcs1_v1_5_decrypt k) blocks
  where
    blocks = chunkify m (fromIntegral (private_size k))
-- | Split a plaintext into OAEP-sized pieces of at most
-- @keySize - 2*hLen - 2@ bytes, where @hLen@ is the hash output length.
chunkBSForOAEP :: RSAKey k =>
                  k ->
                  (ByteString -> ByteString) ->
                  ByteString ->
                  [ByteString]
chunkBSForOAEP k hash bs =
  chunkify bs (keyBytes - 2 * digestBytes - 2)
  where
    keyBytes    = fromIntegral (genKeySize k)
    digestBytes = BS.length (hash BS.empty)
-- | Split a plaintext into PKCS#1 v1.5-sized pieces of at most
-- @keySize - 11@ bytes.
chunkBSForPKCS :: RSAKey k => k -> ByteString -> [ByteString]
chunkBSForPKCS k bstr = chunkify bstr payloadSize
  where
    payloadSize = fromIntegral (genKeySize k) - 11
-- | Cut a lazy 'ByteString' into consecutive pieces of @size@ bytes. The last
-- piece may be shorter, and an empty input yields no pieces.
--
-- NOTE(review): @size@ must be positive. With @size <= 0@ every split is
-- empty, so the (lazy) result is an infinite list of empty chunks.
chunkify :: ByteString -> Int64 -> [ByteString]
chunkify bs size
  | BS.null bs = []
  | otherwise  = start : chunkify rest size
  where
    (start, rest) = BS.splitAt size bs
-- | Encrypt a single block with RSAES-OAEP.
--
-- The data block is @DB = lHash || PS || 0x01 || M@, with @PS@ zero bytes.
-- It is masked with @MGF(seed)@, the random seed is masked with
-- @MGF(maskedDB)@, and @0x00 || maskedSeed || maskedDB@ is encrypted.
-- Fails with 'RSAMessageTooLong' if @M@ exceeds @k - 2*hLen - 2@ bytes.
rsaes_oaep_encrypt :: CryptoRandomGen g =>
                      g ->
                      (ByteString->ByteString) ->
                      MGF ->
                      PublicKey ->
                      ByteString ->
                      ByteString ->
                      Either RSAError (ByteString, g)
rsaes_oaep_encrypt g hash mgf k l m = do
  when (mLen > maxLen) $ Left RSAMessageTooLong
  (seed, g') <- randomBS g hLen
  dbMask <- mgf seed (fromIntegral (kLen - hLen - 1))
  let maskedDB = db `xorBS` dbMask
  seedMask <- mgf maskedDB (fromIntegral hLen)
  let maskedSeed = seed `xorBS` seedMask
      em = BS.concat [BS.singleton 0, maskedSeed, maskedDB]
  cipherRep <- rsa_ep (public_n k) (public_e k) (os2ip em)
  c <- i2osp cipherRep kLen
  return (c, g')
  where
    hLen   = fromIntegral (BS.length (hash BS.empty))
    kLen   = public_size k
    mLen   = fromIntegral (BS.length m)
    maxLen = kLen - 2 * hLen - 2
    ps     = BS.replicate (fromIntegral (kLen - mLen - 2 * hLen - 2)) 0
    db     = BS.concat [hash l, ps, BS.singleton 1, m]
-- | Decrypt a single RSAES-OAEP block.
--
-- Rejects (with 'RSADecryptionError') a ciphertext whose length is not the
-- key size, or a key shorter than @2*hLen + 2@. After unmasking, the leading
-- byte @Y@ must be 0, the label hash must match @hash l@, and the zero
-- padding must be followed by a 0x01 separator. Every padding failure
-- reports the same error.
rsaes_oaep_decrypt :: (ByteString->ByteString) ->
                      MGF ->
                      PrivateKey ->
                      ByteString ->
                      ByteString ->
                      Either RSAError ByteString
rsaes_oaep_decrypt hash mgf k l c = do
  unless (BS.length c == fromIntegral kLen) $ Left RSADecryptionError
  unless (fromIntegral kLen >= 2 * hLen + 2) $ Left RSADecryptionError
  msgRep <- rsa_dp (private_n k) (private_d k) (os2ip c)
  em <- i2osp msgRep kLen
  let (y, seedAndDB)         = BS.splitAt 1 em
      (maskedSeed, maskedDB) = BS.splitAt hLen seedAndDB
  seedMask <- mgf maskedDB hLen
  let seed = maskedSeed `xorBS` seedMask
  dbMask <- mgf seed (fromIntegral kLen - hLen - 1)
  let db              = maskedDB `xorBS` dbMask
      (lHash', rest)  = BS.splitAt hLen db
      -- Skip the all-zero padding; the next byte must be the 0x01 separator.
      (sep, msg)      = BS.splitAt 1 (BS.dropWhile (== 0) rest)
      wellFormed      = BS.unpack y == [0]
                     && lHash' == hash l
                     && BS.unpack sep == [1]
  unless wellFormed $ Left RSADecryptionError
  return msg
  where
    hLen = BS.length (hash BS.empty)
    kLen = private_size k
-- | Encrypt a single block with RSAES-PKCS1-v1_5.
--
-- Builds @0x00 || 0x02 || PS || 0x00 || M@ with @PS@ random nonzero bytes
-- filling the block. Fails with 'RSAIncorrectMsgSize' if @M@ is longer
-- than @k - 11@ bytes.
rsaes_pkcs1_v1_5_encrypt :: CryptoRandomGen g =>
                            g ->
                            PublicKey ->
                            ByteString ->
                            Either RSAError (ByteString, g)
rsaes_pkcs1_v1_5_encrypt g k m
  | fromIntegral (BS.length m) > kLen - 11 = Left RSAIncorrectMsgSize
  | otherwise = do
      (ps, g') <- randomNZBS g (kLen - fromIntegral (BS.length m) - 3)
      let em = BS.concat [BS.pack [0, 2], ps, BS.singleton 0, m]
      cipherRep <- rsa_ep (public_n k) (public_e k) (os2ip em)
      out <- i2osp cipherRep kLen
      return (out, g')
  where
    kLen = public_size k
-- | Decrypt a single RSAES-PKCS1-v1_5 block.
--
-- Fails with 'RSAIncorrectMsgSize' if the ciphertext is not key-sized. After
-- decryption the block must start with @0x00 0x02@ and contain at least 8
-- nonzero padding bytes followed by a 0x00 separator; otherwise it fails
-- with 'RSADecryptionError'.
rsaes_pkcs1_v1_5_decrypt :: PrivateKey -> ByteString ->
                            Either RSAError ByteString
rsaes_pkcs1_v1_5_decrypt k c
  | fromIntegral (BS.length c) /= private_size k = Left RSAIncorrectMsgSize
  | otherwise = do
      msgRep <- rsa_dp (private_n k) (private_d k) (os2ip c)
      em <- i2osp msgRep (private_size k)
      let (header, body)  = BS.splitAt 2 em
          (ps, sepAndMsg) = BS.break (== 0) body
          (sep, msg)      = BS.splitAt 1 sepAndMsg
          wellFormed = BS.unpack header == [0, 2]
                    && BS.unpack sep == [0]
                    && BS.length ps >= 8
      if wellFormed then Right msg else Left RSADecryptionError
-- | RSASSA-PKCS1-v1_5 signature generation. Encodes the message with
-- EMSA-PKCS1-v1_5, applies RSASP1 with the private exponent, and returns
-- the signature as a key-sized byte string.
rsassa_pkcs1_v1_5_sign :: HashInfo ->
                          PrivateKey ->
                          ByteString ->
                          Either RSAError ByteString
rsassa_pkcs1_v1_5_sign hi k m =
  emsa_pkcs1_v1_5_encode hi m kLen
    >>= rsa_sp1 (private_n k) (private_d k) . os2ip
    >>= \sigRep -> i2osp sigRep kLen
  where
    kLen = private_size k
-- | RSASSA-PKCS1-v1_5 signature verification.
--
-- Recovers the encoded message from the signature with RSAVP1 and compares
-- it with a fresh EMSA-PKCS1-v1_5 encoding of @m@. Fails with
-- 'RSAIncorrectSigSize' when the signature is not key-sized. Primitive and
-- encoding errors are returned as 'Left'.
rsassa_pkcs1_v1_5_verify :: HashInfo ->
                            PublicKey ->
                            ByteString ->
                            ByteString ->
                            Either RSAError Bool
rsassa_pkcs1_v1_5_verify hi k m s = do
  unless (BS.length s == fromIntegral kLen) $ Left RSAIncorrectSigSize
  msgRep   <- rsa_vp1 (public_n k) (public_e k) (os2ip s)
  em       <- i2osp msgRep kLen
  expected <- emsa_pkcs1_v1_5_encode hi m kLen
  return (em == expected)
  where
    kLen = public_size k
-- | A mask generation function: given a seed and the desired mask length in
-- bytes, produce the mask or an error.
type MGF = ByteString -> Int64 -> Either RSAError ByteString

-- | MGF1 (RFC 8017, appendix B.2.1) built from the given hash function.
--
-- The mask is @Hash(seed || C(0)) || Hash(seed || C(1)) || ...@, truncated to
-- @maskLen@ bytes, where @C(i)@ is the 4-byte big-endian counter.
--
-- Fails with 'RSAMaskTooLong' when @maskLen > 2^32 * hLen@. This limit
-- applies to the requested mask length, not to the seed. Checking it up
-- front also avoids hashing billions of blocks before the 4-byte counter
-- overflows.
generateMGF1 :: (ByteString -> ByteString) -> MGF
generateMGF1 hash mgfSeed maskLen
  | maskLen > (2 ^ (32 :: Integer)) * hLen = Left RSAMaskTooLong
  | otherwise = do
      counters <- mapM (\i -> i2osp i 4) [0 .. endCounter]
      let blocks = [hash (mgfSeed `BS.append` c) | c <- counters]
      return (BS.take maskLen (BS.concat blocks))
  where
    hLen       = BS.length (hash BS.empty)
    endCounter = (maskLen `divCeil` hLen) - 1
-- | Integer-to-Octet-String primitive (RFC 8017, section 4.1). Encodes a
-- non-negative integer as exactly @len@ big-endian bytes, left-padded with
-- zeros.
--
-- Fails with 'RSAIntegerTooLargeToPack' when @x >= 256^len@, including any
-- negative @len@ (which previously raised a "negative exponent" exception).
-- Fails with 'RSAError' for a negative @x@: repeated @divMod v 256@ on a
-- negative value gets stuck at -1, so the digit stream would never end.
i2osp :: Integral a => a -> Int -> Either RSAError ByteString
i2osp x len
  | n < 0                     = Left (RSAError "i2osp: cannot encode a negative integer")
  | len < 0 || n >= 256 ^ len = Left RSAIntegerTooLargeToPack
  | otherwise                 = Right (padding `BS.append` digits)
  where
    n       = toInteger x
    digits  = BS.reverse (BS.unfoldr digitize n)
    padding = BS.replicate (fromIntegral len - BS.length digits) 0
    digitize 0 = Nothing
    digitize v = let (q, r) = v `divMod` 256
                 in Just (fromIntegral r, q)
-- | Octet-String-to-Integer primitive (RFC 8017, section 4.2). Reads the
-- bytes as a big-endian unsigned integer; the empty string maps to 0.
--
-- Uses a strict left fold so the accumulator is evaluated at every byte,
-- instead of building a chain of thunks as long as the input.
os2ip :: ByteString -> Integer
os2ip = BS.foldl' (\acc byte -> 256 * acc + fromIntegral byte) 0
-- | RSA encryption primitive RSAEP: @m^e mod n@, for a message
-- representative @0 <= m < n@.
rsa_ep :: Integer -> Integer -> Integer -> Either RSAError Integer
rsa_ep = rsaPrimitive RSAMessageRepOutOfRange

-- | RSA decryption primitive RSADP: @c^d mod n@, for @0 <= c < n@.
rsa_dp :: Integer -> Integer -> Integer -> Either RSAError Integer
rsa_dp = rsaPrimitive RSACipherRepOutOfRange

-- | RSA signature primitive RSASP1: @m^d mod n@, for @0 <= m < n@.
rsa_sp1 :: Integer -> Integer -> Integer -> Either RSAError Integer
rsa_sp1 = rsaPrimitive RSAMessageRepOutOfRange

-- | RSA verification primitive RSAVP1: @s^e mod n@, for @0 <= s < n@.
rsa_vp1 :: Integer -> Integer -> Integer -> Either RSAError Integer
rsa_vp1 = rsaPrimitive RSACipherRepOutOfRange

-- | Shared body of the four RSA primitives. Range-checks the representative
-- against the modulus @n@, reporting @outOfRange@ on failure, then
-- exponentiates.
rsaPrimitive :: RSAError -> Integer -> Integer -> Integer -> Either RSAError Integer
rsaPrimitive outOfRange n ex rep
  | rep < 0 || rep >= n = Left outOfRange
  | otherwise           = Right (modular_exponentiation rep ex n)
-- | EMSA-PKCS1-v1_5 encoding (RFC 8017, section 9.2).
--
-- Produces @0x00 || 0x01 || PS || 0x00 || T@, exactly @emLen@ bytes long.
-- @T@ is the algorithm prefix followed by the hash of @m@, and @PS@ is
-- 0xFF padding.
--
-- Fails with 'RSAMessageTooShort' when @emLen < tLen + 11@. The standard
-- requires at least 8 bytes of padding. Below @tLen + 3@ the padding count
-- would even go negative, and the output would be longer than @emLen@.
emsa_pkcs1_v1_5_encode :: HashInfo -> ByteString -> Int ->
                          Either RSAError ByteString
emsa_pkcs1_v1_5_encode (HashInfo ident hash) m emLen
  | fromIntegral emLen < tLen + 11 = Left RSAMessageTooShort
  | otherwise                      = Right em
  where
    t    = ident `BS.append` hash m
    tLen = BS.length t
    ps   = BS.replicate (fromIntegral emLen - tLen - 3) 0xFF
    em   = BS.concat [BS.pack [0x00, 0x01], ps, BS.singleton 0x00, t]
-- | Byte-wise XOR of two strings; the result has the length of the shorter
-- input.
xorBS :: ByteString -> ByteString -> ByteString
xorBS left right = BS.pack (zipWith xor (BS.unpack left) (BS.unpack right))
-- | Division rounded toward positive infinity: the floored quotient, plus one
-- whenever there is a nonzero remainder.
divCeil :: Integral a => a -> a -> a
divCeil a b =
  case a `divMod` b of
    (q, 0) -> q
    (q, _) -> q + 1
-- | Generate two distinct random primes whose byte lengths add up to @len@:
-- @len `div` 2@ bytes and the remainder. The larger prime is returned
-- first. A collision (both primes equal) triggers a fresh attempt.
-- Fails with 'RSAKeySizeTooSmall' when @len < 2@.
generatePQ :: CryptoRandomGen g =>
              g ->
              Int ->
              Either RSAError (Integer, Integer, g)
generatePQ g len
  | len < 2   = Left RSAKeySizeTooSmall
  | otherwise = do
      (p, g1) <- largeRandomPrime g half
      (q, g2) <- largeRandomPrime g1 (len - half)
      if p == q
        then generatePQ g2 len
        else return (max p q, min p q, g2)
  where
    half = len `div` 2
-- | Generate a random probable prime starting from a random @len@-byte
-- candidate.
--
-- The top two bits of the first byte are set, so the product of two such
-- numbers has its top bit set (0.75 * 0.75 > 0.5). The low bit of the last
-- byte is set so the candidate is odd. 'findNextPrime' then searches upward
-- from the candidate.
--
-- Fails with 'RSAKeySizeTooSmall' when @len < 2@, because the first and last
-- bytes are generated separately from the @len - 2@ middle bytes. The
-- 2-byte split uses a total match instead of a partial list pattern.
largeRandomPrime :: CryptoRandomGen g =>
                    g -> Int ->
                    Either RSAError (Integer, g)
largeRandomPrime g len
  | len < 2   = Left RSAKeySizeTooSmall
  | otherwise = do
      (h_t, g') <- randomBS g 2
      (startH, startT) <- case BS.unpack h_t of
        [hi, lo] -> Right (hi, lo)
        _        -> Left (RSAError "largeRandomPrime: expected 2 random bytes")
      (startMids, g'') <- randomBS g' (len - 2)
      let bstr = BS.concat [ BS.singleton (startH .|. 0xc0)
                           , startMids
                           , BS.singleton (startT .|. 1) ]
      findNextPrime g'' (os2ip bstr)
-- | Draw @n@ random bytes from the generator as a lazy 'ByteString'.
-- Generator failures are wrapped in 'RSAGenError'.
randomBS :: CryptoRandomGen g => g -> Int -> Either RSAError (ByteString, g)
randomBS g n = either (Left . RSAGenError) wrap (genBytes n g)
  where
    wrap (bytes, g') = Right (BS.fromChunks [bytes], g')
-- | Draw @size@ random nonzero bytes. Zero bytes from the generator are
-- discarded and the shortfall is drawn again until @size@ bytes remain.
--
-- NOTE(review): only @size == 0@ terminates the recursion directly; a
-- negative @size@ is passed straight to 'randomBS'.
randomNZBS :: CryptoRandomGen g => g -> Int -> Either RSAError (ByteString, g)
randomNZBS gen size
  | size == 0 = Right (BS.empty, gen)
  | otherwise = do
      (raw, gen') <- randomBS gen size
      let nonZero = BS.filter (/= 0) raw
          missing = size - fromIntegral (BS.length nonZero)
      (more, gen'') <- randomNZBS gen' missing
      Right (nonZero `BS.append` more, gen'')
-- | Search upward from @n@, over odd numbers only, for the first probable
-- prime.
--
-- Candidates with @n mod 65537 == 1@ are skipped. Then @n - 1@ is not a
-- multiple of 65537, which is prime, so the public exponent 65537 stays
-- invertible modulo @p - 1@.
findNextPrime :: CryptoRandomGen g =>
                 g -> Integer ->
                 Either RSAError (Integer, g)
findNextPrime g n
  | even n             = findNextPrime g (n + 1)
  | n `mod` 65537 == 1 = findNextPrime g (n + 2)
  | otherwise          =
      isProbablyPrime g n >>= \(isPrime, g') ->
        if isPrime then Right (n, g') else findNextPrime g' (n + 2)
-- | Primality check: trial division by 'small_primes', then Miller-Rabin
-- with 100 witnesses.
--
-- If some listed prime divides @n@, then @n@ is prime exactly when it is
-- itself in the list. The old version used list membership only below 541
-- and called any divisible @n@ composite, so primes 541..1223 were rejected
-- (each divides itself). This version also works for every @n@ up to the
-- largest listed prime without hard-coding that bound. Values below 2 are
-- not prime.
isProbablyPrime :: CryptoRandomGen g =>
                   g ->
                   Integer ->
                   Either RSAError (Bool, g)
isProbablyPrime g n
  | n < 2                    = Right (False, g)
  | any divides small_primes = Right (n `elem` small_primes, g)
  | otherwise                = millerRabin g n 100
  where
    divides p = n `mod` p == 0
-- | All 200 primes up to and including 1223, in ascending order. Used for
-- cheap trial division before Miller-Rabin. Computed once, by trial
-- division, as a top-level constant.
small_primes :: [Integer]
small_primes = filter isPrime [2 .. 1223]
  where
    isPrime k = all (\p -> k `mod` p /= 0) (takeWhile (\p -> p * p <= k) [2 ..])
-- | Miller-Rabin probabilistic primality test with @k@ random witnesses.
--
-- Writes @n - 1 = 2^s * d@ with @d@ odd. Each witness @a@ is drawn from
-- @[2, n-2]@ by rejection sampling over random bytes (uniform if 'genBytes'
-- is). A witness passes if @a^d@ is 1 or @n-1@, or if one of the next
-- @s - 1@ successive squarings reaches @n-1@. Returns @(False, g)@ as soon
-- as a witness proves @n@ composite, or @(True, g)@ after @k@ witnesses
-- pass.
millerRabin :: CryptoRandomGen g =>
               g ->
               Integer ->
               Int ->
               Either RSAError (Bool, g)
millerRabin g n k
  | n <= 0 = Left (RSAError "Primality test on negative number or 0.")
  | n == 1 = Right (False, g)
  | n == 2 = Right (True, g)
  | n == 3 = Right (True, g)
  -- For even n, s would be 0, so the squaring counter below would start at
  -- -1 and never reach its 0 base case.
  | even n = Right (False, g)
  | otherwise =
      let (s, d) = oddify 0 (n - 1)
      in checkLoop g s d k
  where
    generateSize = bitsize (n - 2) 8 `div` 8
    checkLoop :: CryptoRandomGen g =>
                 g -> Integer -> Integer -> Int ->
                 Either RSAError (Bool, g)
    checkLoop g' _ _ 0 = Right (True, g')
    checkLoop g' s d c =
      case genBytes generateSize g' of
        Left e -> Left (RSAGenError e)
        Right (bstr, g'') ->
          let a = os2ip (BS.fromStrict bstr)
              x = modular_exponentiation a d n
          in if | (a < 2) -> checkLoop g'' s d c
                | (a > (n - 2)) -> checkLoop g'' s d c
                | x == 1 -> checkLoop g'' s d (c - 1)
                | x == (n - 1) -> checkLoop g'' s d (c - 1)
                | otherwise -> checkWitnesses g'' s d x c (s - 1)
    -- Square repeatedly (x, x^2, x^4, ...) up to c2 more times, looking for
    -- n-1. Each step continues from the previous square, not from the
    -- original x. Reaching 1 first means a nontrivial square root of 1,
    -- i.e. n is composite.
    checkWitnesses g'' _ _ _ _ 0 = Right (False, g'')
    checkWitnesses g'' s d x c1 c2 =
      case (x * x) `mod` n of
        1 -> Right (False, g'')
        y | y == (n - 1) -> checkLoop g'' s d (c1 - 1)
          | otherwise    -> checkWitnesses g'' s d y c1 (c2 - 1)
    oddify s x | testBit x 0 = (s, x)
               | otherwise = oddify (s + 1) (x `shiftR` 1)
    bitsize v x | (1 `shiftL` x) > v = x
                | otherwise = bitsize v (x + 8)
-- | @modular_exponentiation x y m@ computes @x^y mod m@ by right-to-left
-- square-and-multiply over the bits of @y@.
--
-- The base is reduced modulo @m@ first, and the accumulator starts at
-- @1 mod m@, so the result lies in @[0, m)@ for positive @m@ even when
-- @y == 0@ (@x^0 mod 1 == 0@). A negative exponent is rejected with 'error':
-- shifting a negative 'Integer' right converges to -1, never 0, so the loop
-- would not terminate.
modular_exponentiation :: Integer -> Integer -> Integer -> Integer
modular_exponentiation x y m
  | y < 0     = error "modular_exponentiation: negative exponent"
  | otherwise = go (x `mod` m) y (1 `mod` m)
  where
    -- Invariant: acc * b^e == x^y (mod m). 'seq' forces b and acc each step
    -- so no chain of unevaluated products builds up.
    go _ 0 acc = acc
    go b e acc =
      let b'   = (b * b) `mod` m
          acc' = if testBit e 0 then (acc * b) `mod` m else acc
      in b' `seq` acc' `seq` go b' (e `shiftR` 1) acc'
-- | Multiplicative inverse of @e@ modulo @phi@: the Bezout coefficient of
-- @e@, reduced with 'mod'.
--
-- NOTE(review): coprimality is not checked. If @gcd e phi /= 1@, the result
-- is not an inverse, so callers must guarantee it.
modular_inverse :: Integer ->
                   Integer ->
                   Integer
modular_inverse e phi =
  let (_, coeff, _) = extended_euclidean e phi
  in coeff `mod` phi

-- | Extended Euclid normalised to a non-negative gcd. Returns @(g, x, y)@
-- with @a*x + b*y == g@ and @g >= 0@.
extended_euclidean :: Integer -> Integer -> (Integer, Integer, Integer)
extended_euclidean a b =
  case egcd a b of
    (d, x, y) | d < 0     -> (negate d, negate x, negate y)
              | otherwise -> (d, x, y)

-- | Recursive extended Euclidean algorithm. @egcd a b = (g, x, y)@ with
-- @a*x + b*y == g@; the sign of @g@ follows from the inputs.
egcd :: Integer -> Integer -> (Integer, Integer, Integer)
egcd 0 b = (b, 0, 1)
egcd a b = (g, x - (b `div` a) * y, y)
  where
    (g, y, x) = egcd (b `mod` a) a
-- | PKCS#1 v1.5 'HashInfo' for SHA-1: the DER DigestInfo header for the
-- SHA-1 OID (1.3.14.3.2.26) and a 20-byte digest.
hashSHA1 :: HashInfo
hashSHA1 = HashInfo
  { algorithmIdent = BS.pack [ 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e
                             , 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14 ]
  , hashFunction   = bytestringDigest . sha1
  }

-- | PKCS#1 v1.5 'HashInfo' for SHA-224 (28-byte digest).
hashSHA224 :: HashInfo
hashSHA224 = sha2Info 0x04 28 (bytestringDigest . sha224)

-- | PKCS#1 v1.5 'HashInfo' for SHA-256 (32-byte digest).
hashSHA256 :: HashInfo
hashSHA256 = sha2Info 0x01 32 (bytestringDigest . sha256)

-- | PKCS#1 v1.5 'HashInfo' for SHA-384 (48-byte digest).
hashSHA384 :: HashInfo
hashSHA384 = sha2Info 0x02 48 (bytestringDigest . sha384)

-- | PKCS#1 v1.5 'HashInfo' for SHA-512 (64-byte digest).
hashSHA512 :: HashInfo
hashSHA512 = sha2Info 0x03 64 (bytestringDigest . sha512)

-- | Build the 'HashInfo' for a SHA-2 family member whose OID is
-- 2.16.840.1.101.3.4.2.@oidArc@ and whose digest is @digestLen@ bytes.
--
-- The prefix is the DER encoding of
-- @SEQUENCE { SEQUENCE { OID, NULL }, OCTET STRING header }@. The outer
-- length byte is @0x11 + digestLen@: 15 bytes of AlgorithmIdentifier
-- (2-byte header plus 13-byte body) and the 2-byte OCTET STRING header.
sha2Info :: Int -> Int -> (ByteString -> ByteString) -> HashInfo
sha2Info oidArc digestLen f = HashInfo
  { algorithmIdent = BS.pack (map fromIntegral derBytes)
  , hashFunction   = f
  }
  where
    derBytes =
      [ 0x30, 0x11 + digestLen
      , 0x30, 0x0d
      , 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, oidArc
      , 0x05, 0x00
      , 0x04, digestLen ]