-- Copyright (c) 2006-2011, David Amos. All rights reserved.

{-# LANGUAGE NoMonomorphismRestriction #-}

-- |A module providing functions to test for primality, and find next and previous primes.
module Math.NumberTheory.Prime (primes, isTrialDivisionPrime, isMillerRabinPrime,
                                isPrime, notPrime, prevPrime, nextPrime) where

import System.Random
import System.IO.Unsafe


isTrialDivisionPrime n
    | n > 1 = isNotDivisibleBy primes
    | otherwise = False
    where isNotDivisibleBy (d:ds) | d*d > n         = True
                                  | n `rem` d == 0  = False
                                  | otherwise       = isNotDivisibleBy ds

-- |A (lazy) list of the primes
primes :: [Integer]
primes = 2 : 3 : filter isTrialDivisionPrime (concat [ [m6-1,m6+1] | m6 <- [6,12..] ])

-- initial version. This isn't going to be very good if n has any "large" prime factors (eg > 10000)
pfactors1 n | n > 0 = pfactors' n primes
            | n < 0 = -1 : pfactors' (-n) primes
    where pfactors' n (d:ds) | n == 1 = []
                             | n < d*d = [n]
                             | r == 0 = d : pfactors' q (d:ds)
                             | otherwise = pfactors' n ds
                             where (q,r) = quotRem n d


-- MILLER-RABIN TEST
-- Cohen, A Course in Computational Algebraic Number Theory, p422
-- Koblitz, A Course in Number Theory and Cryptography


-- Let n-1 = 2^s * q, q odd
-- Then n is a strong pseudoprime to base b if
-- either b^q == 1 (mod n)
-- or b^(2^r * q) == -1 (mod n) for some 0 <= r < s
-- (For we know that if n is prime, then b^(n-1) == 1 (mod n)

isStrongPseudoPrime n b =
    let (s,q) = split2s 0 (n-1)  -- n-1 == 2^s * q, with q odd
    in isStrongPseudoPrime' n (s,q) b

isStrongPseudoPrime' n (s,q) b
    | b' == 1 = True
    | otherwise = n-1 `elem` squarings
    where b' = power_mod b q n     -- b' = b^q `mod` n
          squarings = take s $ iterate (\x -> x*x `mod` n) b' -- b^(2^r *q) for 0 <= r < s

-- split2s 0 m returns (s,t) such that 2^s * t == m, t odd
split2s s t = let (q,r) = t `quotRem` 2
              in if r == 0 then split2s (s+1) q else (s,t)

-- power_mod b t n == b^t mod n
power_mod b t n = powerMod' b 1 t
    where powerMod' x y 0 = y
          powerMod' x y t = powerMod' (x*x `rem` n) (if even t then y else x*y `rem` n) (t `div` 2)

isMillerRabinPrime' n
    | n >= 4 =
        let (s,q) = split2s 0 (n-1) -- n-1 == 2^s * q, with q odd
        in do g <- getStdGen
              let rs = randomRs (2,n-1) g
              return $ all (isStrongPseudoPrime' n (s,q)) (take 25 rs)
    | n >= 2 = return True
    | otherwise = return False
-- Cohen states that if we restrict our rs to single word numbers, we can use a more efficient powering algorithm

-- isMillerRabinPrime :: Integer -> Bool
isMillerRabinPrime n = unsafePerformIO (isMillerRabinPrime' n)


-- |Is this number prime? The algorithm consists of using trial division to test for very small factors,
-- followed if necessary by the Miller-Rabin probabilistic test.
isPrime :: Integer -> Bool
isPrime n | n > 1 = isPrime' $ takeWhile (< 100) primes
          | otherwise = False
    where isPrime' (d:ds) | n < d*d = True
                          | otherwise = let (q,r) = quotRem n d
                                        in if r == 0 then False else isPrime' ds
          isPrime' [] = isMillerRabinPrime n
-- the < 100 is found heuristically to be about the point at which trial division stops being worthwhile

notPrime :: Integer -> Bool
notPrime = not . isPrime

-- |Given n, @prevPrime n@ returns the greatest p, p < n, such that p is prime
prevPrime :: Integer -> Integer
prevPrime n | n > 5 = head $ filter isPrime $ candidates
            | n < 3 = error "prevPrime: no previous primes"
            | n == 3 = 2
            | otherwise = 3
            where n6 = (n `div` 6) * 6
                  candidates = dropWhile (>= n) $ concat [ [m6+5,m6+1] | m6 <- [n6, n6-6..] ]

-- |Given n, @nextPrime n@ returns the least p, p > n, such that p is prime
nextPrime :: Integer -> Integer
nextPrime n | n < 2 = 2
            | n < 3 = 3
            | otherwise = head $ filter isPrime $ candidates
            where n6 = (n `div` 6) * 6
                  candidates = dropWhile (<= n) $ concat [ [m6+1,m6+5] | m6 <- [n6, n6+6..] ]

-- slightly better version. This is okay so long as n has at most one "large" prime factor (> 10000)
-- if it has more, it does at least tell you, via an error message, that it has run into difficulties
pfactors2 n | n > 0 = pfactors' n $ takeWhile (< 10000) primes
            | n < 0 = -1 : pfactors' (-n) (takeWhile (< 10000) primes)
    where pfactors' n (d:ds) | n == 1 = []
                             | n < d*d = [n]
                             | r == 0 = d : pfactors' q (d:ds)
                             | otherwise = pfactors' n ds
                             where (q,r) = quotRem n d
          pfactors' n [] = if isMillerRabinPrime n then [n] else error ("pfactors2: can't factor " ++ show n)