module Crypto.Cipher.RSA
( Error(..)
, PublicKey(..)
, PrivateKey(..)
, decrypt
, encrypt
) where
import Control.Arrow (first)
import Crypto.Random
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Number.ModArithmetic (exponantiation_rtl_binary)
-- | Errors that the RSA operations in this module can report.
data Error =
      MessageSizeIncorrect -- ^ reported by 'decrypt' when the ciphertext length differs from the key size
    | MessageTooLong -- ^ reported by 'encrypt' when the message is too long for the key size
    | MessageNotRecognized -- ^ the decrypted block does not have the expected padding layout
    | RandomGenFailure GenError -- ^ the random generator failed; carries the generator's own error
    | KeyInternalError -- ^ a computed integer does not fit in the expected number of bytes
    deriving (Show,Eq)
-- | RSA public key.
data PublicKey = PublicKey
    { public_sz :: Int -- ^ length in bytes of padded messages and ciphertexts (presumably the byte size of the modulus; confirm with key generation)
    , public_n :: Integer -- ^ modulus used for the modular exponentiation
    , public_e :: Integer -- ^ exponent used when encrypting
    } deriving (Show)
-- | RSA private key.
--
-- When both 'private_p' and 'private_q' are non-zero, 'decrypt' uses the
-- faster two-exponentiation path; otherwise it falls back to a single
-- exponentiation with 'private_d' modulo 'private_n'.
data PrivateKey = PrivateKey
    { private_sz :: Int -- ^ expected ciphertext length in bytes, also the length of the decrypted block
    , private_n :: Integer -- ^ modulus used by the slow decryption path
    , private_d :: Integer -- ^ exponent used by the slow decryption path
    , private_p :: Integer -- ^ first modulus of the fast path (presumably a prime factor of n); 0 when unknown
    , private_q :: Integer -- ^ second modulus of the fast path (presumably a prime factor of n); 0 when unknown
    , private_dP :: Integer -- ^ exponent used modulo 'private_p' (presumably d mod (p-1) -- confirm)
    , private_dQ :: Integer -- ^ exponent used modulo 'private_q' (presumably d mod (q-1) -- confirm)
    , private_qinv :: Integer -- ^ coefficient used to recombine the two partial results (presumably q^-1 mod p -- confirm)
    } deriving (Show)
#if ! (MIN_VERSION_base(4,3,0))
-- Compatibility shim, compiled only when base is older than 4.3.0
-- (presumably because newer base versions already ship a Monad instance
-- for Either). A 'Left' short-circuits the computation; a 'Right' feeds
-- its value to the continuation.
instance Monad (Either Error) where
    return = Right
    (Left x) >>= _ = Left x
    (Right x) >>= f = f x
#endif
-- | Build the padded encryption block for message @m@.
--
-- The result is @0x00 || 0x02 || PS || 0x00 || m@, where @PS@ consists of
-- non-zero random bytes drawn from @rng@ and is sized so that the whole
-- block is exactly @len@ bytes long.
--
-- No bound check is done here: the padding length is simply
-- @len - B.length m - 3@, so callers must ensure the message is short enough.
--
-- Returns the padded block together with the updated generator; any error
-- from 'getRandomBytes' is propagated.
padPKCS1 :: CryptoRandomGen g => g -> Int -> ByteString -> Either Error (ByteString, g)
padPKCS1 rng len m = do
    -- Three bytes of the block are taken by the fixed 0x00, 0x02 and 0x00 markers.
    (padding, rng') <- getRandomBytes rng (len - B.length m - 3)
    return (B.concat [ B.singleton 0, B.singleton 2, padding, B.singleton 0, m ], rng')
-- | Strip the encryption padding from a decrypted block.
--
-- The block is accepted only when it starts with the two bytes @0x00 0x02@,
-- continues with at least 8 non-zero bytes and then a @0x00@ separator.
-- Everything after the separator is returned as the message. Any other
-- layout yields 'MessageNotRecognized'.
unpadPKCS1 :: ByteString -> Either Error ByteString
unpadPKCS1 packed
    | wellFormed = Right payload
    | otherwise  = Left MessageNotRecognized
  where
    -- Fixed two-byte header, then everything else.
    (header, rest) = B.splitAt 2 packed
    -- Non-zero filler runs up to (but excluding) the first zero byte.
    (filler, afterFiller) = B.break (== 0) rest
    -- The zero separator, then the actual message.
    (separator, payload) = B.splitAt 1 afterFiller
    wellFormed =
        header == B.pack [0, 2]
            && separator == B.singleton 0
            && B.length filler >= 8
-- | Decryption primitive using a single modular exponentiation.
--
-- Interprets the ciphertext @c@ as an integer, raises it to 'private_d'
-- modulo 'private_n', and serialises the result on 'private_sz' bytes.
-- Fails with whatever 'i2ospOf' reports if the result does not fit.
dpSlow :: PrivateKey -> ByteString -> Either Error ByteString
dpSlow pk c = i2ospOf (private_sz pk) plain
  where
    plain = expmod (os2ip c) (private_d pk) (private_n pk)
-- | Decryption primitive using the two smaller exponentiations modulo
-- 'private_p' and 'private_q' instead of one exponentiation modulo the
-- full modulus.
--
-- The two partial results are recombined as @m2 + h * q@ with
-- @h = qinv * (m1 - m2) mod p@, and the result is serialised on
-- 'private_sz' bytes. Fails with whatever 'i2ospOf' reports if the result
-- does not fit.
dpFast :: PrivateKey -> ByteString -> Either Error ByteString
dpFast pk c = i2ospOf (private_sz pk) (m2 + h * private_q pk)
  where
    iC = os2ip c
    m1 = expmod iC (private_dP pk) (private_p pk)
    m2 = expmod iC (private_dQ pk) (private_q pk)
    -- 'mod' yields a non-negative result for a positive modulus even when
    -- m1 < m2, so h always lies in [0, p).
    h  = (private_qinv pk * (m1 - m2)) `mod` private_p pk
-- | Decrypt the ciphertext @c@ with the private key @pk@ and remove the
-- padding.
--
-- The ciphertext must be exactly 'private_sz' bytes long, otherwise
-- 'MessageSizeIncorrect' is returned. The fast primitive is selected when
-- both 'private_p' and 'private_q' are non-zero, the slow one otherwise.
-- Errors from the primitive or from 'unpadPKCS1' are propagated.
decrypt :: PrivateKey -> ByteString -> Either Error ByteString
decrypt pk c
    | B.length c == private_sz pk = unpadPKCS1 =<< rawDecrypt pk c
    | otherwise                   = Left MessageSizeIncorrect
  where
    hasFactors = private_p pk /= 0 && private_q pk /= 0
    rawDecrypt
        | hasFactors = dpFast
        | otherwise  = dpSlow
-- | Encrypt the message @m@ with the public key @pk@, drawing the padding
-- bytes from @rng@.
--
-- On success returns the ciphertext, serialised on 'public_sz' bytes,
-- together with the updated generator.
--
-- Fails with 'MessageTooLong' when the message is longer than
-- @public_sz pk - 11@ bytes; errors from 'padPKCS1' and 'i2ospOf' are
-- propagated.
encrypt :: CryptoRandomGen g => g -> PublicKey -> ByteString -> Either Error (ByteString, g)
encrypt rng pk m
    -- 11 = 3 fixed marker bytes plus a minimum of 8 random padding bytes.
    | B.length m > public_sz pk - 11 = Left MessageTooLong
    | otherwise = do
        (em, rng') <- padPKCS1 rng (public_sz pk) m
        c <- i2ospOf (public_sz pk) $ expmod (os2ip em) (public_e pk) (public_n pk)
        return (c, rng')
-- | Produce exactly @n@ random bytes, none of which is zero.
--
-- Bytes are requested from the generator, zero bytes are discarded, and the
-- function recurses with the updated generator to make up for however many
-- bytes were dropped. Returns the bytes and the final generator state.
--
-- A failure of the generator is reported as 'RandomGenFailure'.
getRandomBytes :: CryptoRandomGen g => g -> Int -> Either Error (ByteString, g)
getRandomBytes rng n = do
    gend <- either (Left . RandomGenFailure) Right $ genBytes n rng
    let (bytes, rng') = first (B.filter (/= 0)) gend
        -- Number of bytes lost to the zero filter that still have to be drawn.
        missing = n - B.length bytes
    if missing == 0
        then return (bytes, rng')
        else fmap (first (B.append bytes)) (getRandomBytes rng' missing)
-- | Serialise the integer @m@ as a big-endian byte string of exactly @len@
-- bytes.
--
-- The minimal encoding given by 'i2osp' is left-padded with zero bytes up to
-- @len@. When the minimal encoding is already longer than @len@ the value
-- cannot be represented and 'KeyInternalError' is returned.
i2ospOf :: Int -> Integer -> Either Error ByteString
i2ospOf len m
    | lenbytes < len  = Right $ B.replicate (len - lenbytes) 0 `B.append` bytes
    | lenbytes == len = Right bytes
    | otherwise       = Left KeyInternalError
  where
    lenbytes = B.length bytes
    bytes    = i2osp m
-- | Interpret a byte string as a non-negative integer, most significant
-- byte first. The empty string maps to 0.
os2ip :: ByteString -> Integer
os2ip = B.foldl' step 0
  where
    -- Make room for one more byte at the low end, then drop it in.
    step acc byte = (acc `shiftL` 8) .|. fromIntegral byte
-- | Serialise a non-negative integer as a minimal big-endian byte string.
-- Zero is encoded as the empty string.
--
-- The input is expected to be non-negative: with a negative input the
-- quotient never reaches zero, so the unfolding does not terminate.
i2osp :: Integer -> ByteString
i2osp = B.reverse . B.unfoldr step
  where
    -- Emit the least significant byte and carry on with the rest; the bytes
    -- therefore come out little-endian and are reversed afterwards.
    step n
        | n == 0    = Nothing
        | otherwise =
            let (rest, low) = n `divMod` 256
             in Just (fromIntegral low, rest)
-- | Modular exponentiation, delegated to 'exponantiation_rtl_binary'.
--
-- Throughout this module it is called as @expmod base exponent modulus@
-- (NOTE(review): argument meaning inferred from the call sites; confirm
-- against the "Number.ModArithmetic" documentation).
expmod :: Integer -> Integer -> Integer -> Integer
expmod b e m = exponantiation_rtl_binary b e m