module Codec.Encryption.RSA.NumberTheory( 
inverse, extEuclGcd, simplePrimalityTest, getPrime, pg, isPrime, 
rabinMillerPrimalityTest, expmod, factor, testInverse, primes, (/|),
randomOctet
)    where
import System.Random(getStdRandom,randomR)
import Data.List(elemIndex)
import Data.Maybe(fromJust)
import Data.Char(chr,ord)
import Data.Bits(xor)
-- | Produce a string of @n@ pseudo-random octets: each character has a code
-- point drawn by 'randomR' from the range 0..255 using the global standard
-- generator. Calls 'error' when @n@ is negative.
randomOctet :: Int -> IO String
randomOctet n
  | n >= 0    = fmap (map chr) (mapM (const drawByte) [1 .. n])
  | otherwise = error "randomOctet argument doesn't meet preconditions"
  where
    -- A single byte value taken from the global generator.
    drawByte :: IO Int
    drawByte = getStdRandom (randomR (0, 255))
-- | Prime-exponent vector of a number. Element @k@ is the exponent of the
-- @k@-th prime (counting 2 as the 0th) in the argument, and trailing zeros
-- are dropped. This simply delegates to 'factor_1'.
factor :: Integer -> [Int]
factor n = factor_1 n
-- | Prime-exponent vector of @a@. It computes 'largestPower' for every prime
-- up to @a@, in ascending order, and strips the trailing zero exponents.
-- The result is @[]@ for @a < 2@.
factor_1 :: Integer -> [Int]
factor_1 a = trimTrailingZeros exponents
  where
    exponents = [largestPower q a | q <- takeWhile (<= a) primes]
    -- Right fold that drops a zero only when nothing non-zero follows it.
    trimTrailingZeros = foldr keep []
    keep e rest
      | e == 0 && null rest = []
      | otherwise           = e : rest
-- | Variant of the prime-exponent factorisation that returns 'Integer'
-- exponents. Only primes up to @a `div` 2@ are tried as factors.
--
-- If none of them divides @a@, and @a >= 2@, then @a@ is itself prime. In
-- that case the result is one zero for every smaller prime, followed by a
-- single 1 at @a@'s own position. For example, @factor_2 5 == [0,0,1]@.
-- Returns @[]@ for @a < 2@; the previous lookup never terminated there.
factor_2 :: Integer -> [Integer]
factor_2 a
  | a < 2     = []
  -- The number of primes below a is exactly a's index in 'primes'.
  | null p    = replicate (length (takeWhile (< a) primes)) 0 ++ [1]
  | otherwise = p
  where
    p = map fromIntegral . reverse . dropWhile (== 0)
      . reverse . map (\x -> largestPower x a)
      . takeWhile (<= a `div` 2) $ primes
 
-- | Inverse of @x@ modulo @n@. It takes the first Bezout coefficient from
-- 'extEuclGcd' and reduces it with 'mod', which gives a value in @[0, n)@
-- for positive @n@. The result is only meaningful when @gcd x n == 1@; this
-- is not checked.
inverse :: Integer -> Integer -> Integer
inverse x n = s `mod` n
  where
    (s, _) = extEuclGcd x n
-- | Checks whether @inverse a b@ really is the inverse of @a@ modulo @b@.
testInverse :: Integer -> Integer -> Bool
testInverse a b = (a * inverse a b) `mod` b == 1
-- | Extended Euclidean algorithm. Returns Bezout coefficients @(s, t)@ such
-- that @s*a + t*b == gcd a b@, where the gcd is non-negative (as with
-- Prelude 'gcd').
--
-- Zero arguments are accepted: @extEuclGcd a 0 == (1, 0)@ for @a >= 0@.
-- Negative arguments are handled by running the algorithm on the absolute
-- values and then flipping the sign of the matching coefficient.
extEuclGcd :: Integer -> Integer -> (Integer,Integer)
extEuclGcd a b = (sign a * s, sign b * t)
  where
    (s, t) = extEuclGcd_iter (abs a) (abs b) (1,0) (0,1)
    -- Turns a coefficient for |v| into one for v.
    sign v = if v < 0 then -1 else 1

-- | Worker for 'extEuclGcd'; both operands must be non-negative.
-- Invariant, with A and B the original operands:
-- @a == c1*A + c2*B@ and @b == d1*A + d2*B@.
-- Each step replaces the larger operand by its remainder modulo the smaller
-- one and updates that operand's coefficient pair to keep the invariant.
-- It stops as soon as one operand divides the other and returns the
-- divisor's coefficients.
extEuclGcd_iter :: Integer -> Integer
  -> (Integer,Integer) -> (Integer,Integer) -> (Integer,Integer)
extEuclGcd_iter a b (c1,c2) (d1,d2)
  | b == 0    = (c1,c2)   -- gcd a 0 == a; avoids a division by zero
  | a == 0    = (d1,d2)   -- gcd 0 b == b
  | a > b     = if r1 == 0
                  then (d1,d2)
                  else extEuclGcd_iter r1 b (c1 - q1*d1, c2 - q1*d2) (d1,d2)
  | otherwise = if r2 == 0
                  then (c1,c2)
                  else extEuclGcd_iter a r2 (c1,c2) (d1 - q2*c1, d2 - q2*c2)
  where
    (q1, r1) = a `divMod` b
    (q2, r2) = b `divMod` a
-- | Draw a random @n@-bit number from the global standard generator, i.e. a
-- value in @[2^(n-1), 2^n - 1]@. Calls 'error' when @n < 1@, since then
-- @2^(n-1)@ would have a negative exponent.
getNumber :: Int -> IO Integer
getNumber n
  | n < 1     = error "getNumber argument doesn't meet preconditions"
  | otherwise = do
      i <- getStdRandom (randomR (0, a - 1))
      return (i + a)
  where
    -- Smallest n-bit number; it is also the size of the range drawn from.
    a = (2 ^ (n - 1)) :: Integer
-- | Generate a random probable prime with @nBits@ bits. It draws a number
-- with 'getNumber', bumps an even draw to the next odd value, and retries
-- until 'isPrime' accepts the candidate.
--
-- Calls 'error' for @nBits < 2@: there is no prime with fewer than 2 bits,
-- so the retry loop could never finish.
getPrime :: Int -> IO Integer
getPrime nBits
  | nBits < 2 = error "getPrime argument doesn't meet preconditions"
  | otherwise = do
      r <- getNumber nBits
      -- @2 /| r@ means r is odd. An even r is bumped to r + 1. This stays
      -- within nBits bits, assuming getNumber returns values below
      -- 2^nBits (an even value there is at most 2^nBits - 2).
      let p = if 2 /| r then r else r + 1
      pIsPrime <- isPrime p
      if pIsPrime
        then return p
        else getPrime nBits
-- | Repeatedly draw a random integer from @[lo, hi]@ until one is accepted
-- by 'isPrime' and is coprime to @e@, then return it. If the range holds no
-- such prime, this loops forever.
pg :: Integer -> Integer -> Integer -> IO Integer
pg lo hi e = attempt
  where
    attempt = do
      candidate <- getStdRandom (randomR (lo, hi))
      accepted <- isPrime candidate
      if accepted && gcd candidate e == 1
        then return candidate
        else attempt
-- | Primality test.
--
-- * Values up to 2000 are decided by trial division with
--   'simplePrimalityTest'.
-- * Larger values must first survive trial division by the primes below
--   2000.
-- * They must then pass five 'rabinMillerPrimalityTest' rounds. All five
--   rounds always run.
isPrime :: Integer -> IO Bool
isPrime a
  | a <= 1                      = return False
  | a <= 2000                   = return (simplePrimalityTest a)
  | not (simplePrimalityTest a) = return False
  | otherwise                   =
      fmap and (mapM rabinMillerPrimalityTest (replicate 5 a))
-- | Trial division by every prime below @min 2000 a@.
--
-- * For @2 <= a <= 2000@ this decides primality.
-- * For larger @a@ it only rules out small factors.
-- * For @a < 3@ there are no primes to try, so it returns 'True'.
simplePrimalityTest :: Integer -> Bool
simplePrimalityTest a = all (\q -> q /| a) candidates
  where
    candidates = takeWhile (< min 2000 a) primes
-- | Exponent of the largest power of @x@ that divides @y@; for example,
-- @largestPower 2 12 == 2@. It divides @y@ by @x@ while the division is
-- exact, rather than testing @y `mod` x^k@ for every @k@.
--
-- Calls 'error' when @|x| <= 1@ or @y == 0@. In those cases the exponent is
-- unbounded or undefined, and the old search either never terminated or
-- divided by zero.
largestPower :: Integer -> Integer -> Int
largestPower x y
  | abs x <= 1 || y == 0 =
      error "largestPower: need |x| > 1 and y /= 0"
  | otherwise =
      length . takeWhile (\r -> r `mod` x == 0) . iterate (`div` x) $ y
-- | One randomized Rabin-Miller round for @p@. It writes @p - 1 = 2^b * m@
-- with @m@ odd and tests a single random witness. 'True' means "probably
-- prime"; 'False' means the round found @p@ to be composite.
--
-- Values below 4 and even values are decided directly, without a random
-- draw. Previously @p < 2@ never terminated and even @p@ could hit a
-- missing case in the squaring chain.
rabinMillerPrimalityTest :: Integer -> IO Bool
rabinMillerPrimalityTest p
  | p < 2     = return False
  | p < 4     = return True        -- 2 and 3
  | even p    = return False
  | otherwise = rabinMillerPrimalityTest_iter_1 p b m
  where
    b = fromIntegral $ largestPower 2 (p - 1)
    m = (p - 1) `div` (2 ^ b)
-- | Pick a random witness @a@ and run the squaring chain of
-- 'rabinMillerPrimalityTest_iter_2' starting from @a^m mod p@, where
-- @p - 1 = 2^b * m@.
--
-- The witness is drawn from @[2, min 2000 (p - 2)]@, for these reasons:
--
-- * Witness 0, or any multiple of p, produces a chain of zeros, which
--   reports a composite even when p is prime.
-- * Witnesses 1 and p - 1 always pass, so they carry no information.
--
-- Moduli below 4 leave no usable witness and are decided directly.
rabinMillerPrimalityTest_iter_1 :: Integer -> Integer -> Integer -> IO Bool
rabinMillerPrimalityTest_iter_1 p b m
  | p < 4     = return (p == 2 || p == 3)
  | otherwise = do
      a <- getStdRandom (randomR (2, min 2000 (p - 2)))
      return (rabinMillerPrimalityTest_iter_2 p b 0 (expmod a m p))
-- | Squaring chain of one Rabin-Miller round.
--
-- * @z@ is @a^(2^j * m) mod p@ at step @j@, which starts at 0.
-- * @b@ is the exponent of 2 in @p - 1@.
--
-- Returns 'True' ("probably prime") when the chain starts at 1 or reaches
-- @p - 1@ within @b@ steps. Returns 'False' otherwise.
--
-- In particular, reaching 1 at a later step means the previous value was a
-- non-trivial square root of 1, which proves @p@ composite. The old first
-- guard accepted that case.
rabinMillerPrimalityTest_iter_2 :: Integer -> Integer -> Integer -> Integer
  -> Bool
rabinMillerPrimalityTest_iter_2 p b j z
  | j == 0 && z == 1 = True
  | z == p - 1       = True
  | z == 1           = False   -- previous z squared to 1 without being -1
  | j + 1 < b        =
      rabinMillerPrimalityTest_iter_2 p b (j + 1) ((z * z) `mod` p)
  | otherwise        = False
-- | Modular exponentiation by repeated squaring: @expmod a x m == a^x mod m@.
--
-- * For positive @m@ the result lies in @[0, m)@. This includes
--   @x == 0@, so @expmod a 0 1 == 0@.
-- * Calls 'error' for a negative exponent; that case used to recurse
--   forever.
expmod :: Integer -> Integer -> Integer -> Integer
expmod a x m
  | x < 0     = error "expmod: negative exponent"
  | x == 0    = 1 `mod` m
  | x == 1    = a `mod` m
  | even x    = let h = expmod a (x `div` 2) m
                in  (h * h) `mod` m
  | otherwise = (a * expmod a (x - 1) m) `mod` m
-- | Integer square root: the largest @r@ with @r*r <= i@. It is computed
-- exactly with Newton's method on 'Integer'.
--
-- The previous round-trip through 'Double' could be off by one for large
-- inputs; for example, @10^40 - 1@ gave @10^20@. It also overflowed to
-- Infinity for huge ones.
--
-- Calls 'error' for negative arguments.
intSqrt :: Integer -> Integer
intSqrt i
  | i < 0     = error "intSqrt: negative argument"
  | i < 2     = i
  | otherwise = go i
  where
    -- Newton iteration from the over-estimate i. It decreases strictly
    -- until it reaches floor (sqrt i), where the next step stops shrinking.
    go x =
      let y = (x + i `div` x) `div` 2
      in  if y >= x then x else go y
-- | @a /| b@ holds when @a@ does not divide @b@, i.e. the remainder of @b@
-- by @a@ is non-zero.
(/|) :: Integer -> Integer -> Bool
(/|) a b = rem b a /= 0
-- | Infinite ascending list of primes. Each odd candidate is trial-divided
-- by the primes already found that do not exceed its integer square root.
primes :: [Integer]
primes = 2 : filter hasNoSmallFactor [3, 5 ..]
  where
    hasNoSmallFactor x = all (\q -> q /| x) (takeWhile (<= limit) primes)
      where
        limit = intSqrt x