{-# OPTIONS_GHC -O2 #-}

-- | This module implements RSA encryption with PKCS1 padding.

module Codec.Encryption.PKCS1
	,i2osp) where

import Data.Bits
import Data.ByteString as BS hiding (take,split,zipWith)
import Data.ByteString.Char8 as BSC hiding (take,split,zipWith)
import System.Random

-- | An RSA public key, as typically carried inside a certificate.
data PublicKey = PublicKey
  { publicN :: Integer -- ^ The modulus
  , publicE :: Integer -- ^ The public exponent
  } deriving (Show, Eq)

-- | An RSA private key in one of two representations: 'Left' holds the
--   plain modulus/exponent form, 'Right' holds the form that additionally
--   carries the prime factors.
newtype PrivateKey
  = PrivateKey (Either PrivateKeySimple PrivateKeyComplex)

-- | Private key consisting only of modulus and private exponent.
--   Decrypting with this variant is slower than with 'PrivateKeyComplex'.
data PrivateKeySimple = PrivateKeySimple
  { privateN :: Integer -- ^ The modulus
  , privateD :: Integer -- ^ The private exponent
  } deriving (Show, Eq)

-- | Private key that also carries the prime factors of the modulus, which
--   makes decryption considerably faster than with 'PrivateKeySimple'.
data PrivateKeyComplex = PrivateKeyComplex
  { privateN' :: Integer -- ^ The modulus
  , privateD' :: Integer -- ^ The private exponent
  , privateP  :: Integer -- ^ The first prime factor
  , privateQ  :: Integer -- ^ The second prime factor
  , privateU  :: Integer -- ^ Recombination coefficient: decryption multiplies
                         --   the difference of the two partial results by it
                         --   modulo 'privateP', so it is expected to be the
                         --   inverse of 'privateQ' modulo 'privateP'
  } deriving (Show, Eq)

-- | Encrypts a ByteString using RSA with PKCS1 padding. The input is cut
--   into chunks of at most @sz - 11@ bytes; every chunk is padded and
--   encrypted on its own, and the resulting blocks are concatenated.
encrypt :: RandomGen g
        => Int         -- ^ Size of the modulus in bytes
        -> PublicKey   -- ^ The public key to use
        -> g           -- ^ Source of randomness for the padding bytes
        -> ByteString  -- ^ The string to encrypt
        -> ByteString
encrypt sz key gen str = BS.concat cipherBlocks
  where
    -- Each chunk gets its own generator, split off from the one supplied.
    plainChunks  = splitBlock (sz - 11) str
    cipherBlocks = zipWith (rsaPkcs1encrypt sz key) (gens gen) plainChunks

-- | Decrypts a ByteString using RSA with PKCS1 padding. The input is cut
--   into blocks of @sz@ bytes that are decrypted independently; if any
--   block fails to decrypt, the overall result is 'Nothing'.
decrypt :: Int         -- ^ Size of the modulus in bytes
        -> PrivateKey  -- ^ The private key to use
        -> ByteString  -- ^ The string to decrypt
        -> Maybe ByteString
decrypt sz key str = fmap BS.concat plainBlocks
  where
    plainBlocks = mapM (rsaPkcs1decrypt sz key) (splitBlock sz str)

-- | Turns one generator into an infinite list of generators by splitting
--   it repeatedly.
gens :: RandomGen g => g -> [g]
gens g = current : gens remainder
  where
    (current, remainder) = split g

-- | Encrypts a single block. The encoded message is laid out as
--   @0x00 0x02 padding 0x00 message@, where the padding is
--   @sz - length str - 3@ random bytes drawn from the range 1..255 (so it
--   never contains a zero byte), and is then passed through 'rsaep'.
rsaPkcs1encrypt :: RandomGen g => Int -> PublicKey -> g -> ByteString -> ByteString
rsaPkcs1encrypt sz key gen str = i2osp sz (rsaep key (os2ip encoded))
  where
    padLen  = sz - BS.length str - 3
    padding = take padLen (map fromIntegral (randomRs (1, 255 :: Int) gen))
    encoded = BS.concat [BS.pack (0x00 : 0x02 : padding), BS.singleton 0x00, str]

-- | Decrypts a single block and strips the padding. The decrypted block
--   must start with @0x00 0x02@, followed by at least eight non-zero
--   padding bytes and a zero separator; everything after the separator is
--   the message. Any other layout yields 'Nothing'.
rsaPkcs1decrypt :: Int -> PrivateKey -> ByteString -> Maybe ByteString
rsaPkcs1decrypt sz key str
  | headerOk && paddingOk && hasSeparator = Just (BS.drop 1 fromSeparator)
  | otherwise                             = Nothing
  where
    decoded                  = i2osp sz (rsadp key (os2ip str))
    (header, body)           = BS.splitAt 2 decoded
    (padding, fromSeparator) = BS.break (== 0) body
    headerOk                 = header == BS.pack [0x00, 0x02]
    paddingOk                = BS.length padding >= 8
    -- 'fromSeparator' starts at the first zero byte, if there is one.
    hasSeparator             = not (BS.null fromSeparator)

-- | RSA encryption primitive: raises the message representative to the
--   public exponent modulo the public modulus.
rsaep :: PublicKey -> Integer -> Integer
rsaep key m = expmod m e n
  where
    e = publicE key
    n = publicN key

-- | RSA decryption primitive: dispatches to the routine matching the
--   representation stored in the private key.
rsadp :: PrivateKey -> Integer -> Integer
rsadp (PrivateKey k) = case k of
  Left simple   -> rsadpSimple simple
  Right complex -> rsadpComplex complex

-- | Decryption with the plain key: a single exponentiation with the full
--   private exponent modulo the full modulus (the slow variant).
rsadpSimple :: PrivateKeySimple -> Integer -> Integer
rsadpSimple key c = expmod c d n
  where
    d = privateD key
    n = privateN key

-- | Decryption using the prime factors (the fast variant): two
--   exponentiations with reduced exponents modulo each prime, recombined
--   afterwards using 'privateU'.
rsadpComplex :: PrivateKeyComplex -> Integer -> Integer
rsadpComplex key c = m2 + q * h
  where
    d  = privateD' key
    p  = privateP key
    q  = privateQ key
    -- Partial results modulo each prime, with the exponent reduced
    -- modulo (prime - 1).
    m1 = expmod c (d `mod` (p - 1)) p
    m2 = expmod c (d `mod` (q - 1)) q
    -- Recombination step; `mod` keeps h non-negative even when m1 < m2.
    h  = ((m1 - m2) * privateU key) `mod` p

-- | Modular exponentiation: @expmod a x m@ computes @a^x `mod` m@ by
--   binary square-and-multiply.
--
--   The exponent must be non-negative. A negative exponent raises an error
--   instead of looping forever (halving a negative number with 'div' never
--   reaches zero). A modulus of 1 always yields 0.
expmod :: Integer -> Integer -> Integer -> Integer
expmod a x m
  | x < 0     = error "Codec.Encryption.PKCS1.expmod: negative exponent"
  | m == 1    = 0
  | otherwise = expmod' a x m 1

-- | Worker for 'expmod'. Arguments are the current base, the remaining
--   exponent, the modulus and the result accumulated so far. Each step
--   squares the base, halves the exponent and, when the exponent is odd,
--   multiplies the base into the accumulator.
expmod' :: Integer -> Integer -> Integer -> Integer -> Integer
expmod' a x m res
  | x < 0     = error "Codec.Encryption.PKCS1.expmod': negative exponent"
  | x == 0    = res
  -- Force the accumulator on every step so no chain of thunks builds up.
  | otherwise = acc `seq` expmod' ((a * a) `mod` m) (x `div` 2) m acc
  where
    acc
      | odd x     = (res * a) `mod` m
      | otherwise = res

-- | Converts a block of bytes into a number. The bytes are read
--   big-endian: the first byte is the most significant one. An empty
--   block yields 0.
os2ip :: (Bits a,Integral a) => ByteString -> a
os2ip bs = BS.foldl' step 0 bs
  where
    -- Shift what has been read so far up by one byte and add the next
    -- byte. A single accumulator lets the strict fold force the running
    -- value on every step; a pair accumulator would leave it unevaluated.
    step acc w = shiftL acc 8 + fromIntegral w

-- | Converts a number into a block of exactly @len@ bytes, most
--   significant byte first. Only the lowest @len@ bytes of the number are
--   emitted; smaller numbers are padded with leading zero bytes.
i2osp :: (Bits a,Integral a) => Int -> a -> ByteString
i2osp len num = fst (BS.unfoldrN len nextByte (len - 1))
  where
    -- The seed is the index of the byte to emit, counting down to 0.
    nextByte idx
      | idx < 0   = Nothing
      | otherwise = Just (fromIntegral (num `shiftR` (idx * 8)), idx - 1)

-- | Cuts a ByteString into consecutive chunks of @sz@ bytes; the last
--   chunk may be shorter. An empty input yields no chunks.
--
--   The chunk size must be positive. A non-positive size on non-empty
--   input raises an error, because splitting would never consume any
--   input and the list of chunks would be infinite.
splitBlock :: Int -> ByteString -> [ByteString]
splitBlock sz str
  | BS.null str = []
  | sz <= 0     = error "Codec.Encryption.PKCS1.splitBlock: non-positive block size"
  | otherwise   = chunk : splitBlock sz rest
  where
    (chunk, rest) = BS.splitAt sz str

-- Optimisation hints for the compiler: the byte/number conversions are
-- marked for inlining and specialised to Integer, and the encryption
-- functions are specialised to the StdGen random generator.
{-# INLINE os2ip #-}
{-# INLINE i2osp #-}
{-# SPECIALIZE os2ip :: ByteString -> Integer #-}
{-# SPECIALIZE i2osp :: Int -> Integer -> ByteString #-}
{-# SPECIALIZE encrypt :: Int -> PublicKey -> StdGen -> ByteString -> ByteString #-}
{-# SPECIALIZE rsaPkcs1encrypt :: Int -> PublicKey -> StdGen -> ByteString -> ByteString #-}