module Arithmetic.Prime
where
import OpenTheory.Primitive.Natural
import OpenTheory.Primitive.Random as Random
import OpenTheory.Natural
import qualified OpenTheory.Natural.Uniform as Uniform
import Arithmetic.Random
import Arithmetic.Utility
import qualified Arithmetic.Modular as Modular
import qualified Arithmetic.Prime.Sieve as Sieve
-- | The list of primes, beginning with 2.
--
-- Everything after 2 comes from 'Sieve.advance' starting at 'Sieve.initial'
-- (presumably the remaining primes in ascending order; the meaning of the
-- arguments @1@ and @4@ is defined in "Arithmetic.Prime.Sieve" — confirm
-- there).
primes :: [Natural]
primes = 2 : sieved
  where
    sieved = Sieve.advance 1 4 Sieve.initial
-- | @millerRabinWitness n a@ is 'True' when the base @a@ exposes @n@ as
-- composite under the Miller-Rabin test.
--
-- Writing @n - 1@ via 'factorTwos' as @(k, d)@ (assumed @n - 1 = 2^k * d@
-- with @d@ odd — TODO confirm), the base is raised to @d@ modulo @n@ and
-- then squared up to @k@ times. The base is a witness if either:
--
-- * a value other than @1@ and @n - 1@ squares to @1@, or
-- * after all @k@ squarings the value is not @1@.
--
-- Intended for odd @n > 3@ with @2 <= a <= n - 2@, which is how
-- 'millerRabin' calls it. The 'factorTwos' result is computed once per
-- partial application @millerRabinWitness n@.
millerRabinWitness :: Natural -> Natural -> Bool
millerRabinWitness n = isWitness
  where
    nMinus1 = n - 1
    (twos, oddPart) = factorTwos nMinus1
    isWitness a = search twos (Modular.exp n a oddPart)
    -- k: squarings still available; x: current power of the base mod n.
    search k x
      | k == 0 = x /= 1
      -- x is a square root of 1; only 1 and n - 1 are allowed for a prime.
      | y == 1 = x /= 1 && x /= nMinus1
      | otherwise = search (k - 1) y
      where
        y = Modular.square n x
-- | @millerRabin t n r@ probabilistically tests @n@ for primality using @t@
-- Miller-Rabin rounds driven by the random source @r@.
--
-- Handling of special values:
--
-- * 2 and 3 are accepted outright.
-- * 1, and any @n@ rejected by 'naturalOdd', are refused.
--
-- Otherwise each round splits the random source. It draws a base as
-- @2 + Uniform.random (n - 3)@, presumably uniform in @[2, n - 2]@ if
-- 'Uniform.random' k yields values in @[0, k)@ — confirm. The number is
-- accepted only if no drawn base is a 'millerRabinWitness'.
millerRabin :: Natural -> Natural -> Random.Random -> Bool
millerRabin t n = test
  where
    test gen
      | n == 2 || n == 3 = True
      | n == 1 || not (naturalOdd n) = False
      | otherwise = rounds t gen
    rounds k gen
      | k == 0 = True
      | otherwise = passes g1 && rounds (k - 1) g2
      where
        (g1, g2) = Random.split gen
    passes = not . millerRabinWitness n . pickBase
    pickBase gen = 2 + Uniform.random (n - 3) gen
-- | Probabilistic primality test: 'millerRabin' with 100 rounds.
--
-- A 'True' answer for a composite input is possible but improbable; the
-- result depends on the supplied random source.
isPrime :: Natural -> Random.Random -> Bool
isPrime n r = millerRabin rounds n r
  where
    rounds = 100
-- | @previousPrime n r@ is the largest number @<= n@ that 'isPrime' accepts.
--
-- It tries @n, n - 1, n - 2, ...@ and uses a fresh split of the random source
-- for each candidate. The search always stops by 2, which 'isPrime' accepts
-- unconditionally.
--
-- There is no prime at or below 0 or 1. For such inputs this function
-- reports an 'error' rather than recursing into @0 - 1@, which would either
-- loop forever or underflow depending on 'Natural' subtraction.
previousPrime :: Natural -> Random.Random -> Natural
previousPrime n r
  | n < 2 = error "Arithmetic.Prime.previousPrime: no prime <= input"
  | otherwise = search n r
  where
    -- Invariant: m >= 2, so reaching m == 2 terminates the search.
    search m gen =
        if isPrime m g1 then m else search (m - 1) g2
      where
        (g1, g2) = Random.split gen
-- | @nextPrime n r@ is the smallest number @>= n@ that 'isPrime' accepts.
--
-- It tries @n, n + 1, n + 2, ...@ and uses a fresh split of the random source
-- for each candidate.
nextPrime :: Natural -> Random.Random -> Natural
nextPrime n r
  | isPrime n r1 = n
  | otherwise = nextPrime (n + 1) r2
  where
    (r1, r2) = Random.split r
-- | @nextPrime3Mod4 n r@ is the smallest number @>= n@ that is congruent to
-- 3 modulo 4 and accepted by 'isPrime'.
--
-- It starts at the first value @>= n@ with residue 3 and steps by 4. Each
-- candidate gets a fresh split of the random source.
nextPrime3Mod4 :: Natural -> Random.Random -> Natural
nextPrime3Mod4 n = search start
  where
    -- Smallest number >= n that is congruent to 3 (mod 4).
    start = 4 * (n `div` 4) + 3
    search m gen
      | isPrime m g1 = m
      | otherwise = search (m + 4) g2
      where
        (g1, g2) = Random.split gen
-- | @nextPrime5Mod8 n r@ is the smallest number @>= n@ that is congruent to
-- 5 modulo 8 and accepted by 'isPrime'.
--
-- It starts at the first value @>= n@ with residue 5 and steps by 8. Each
-- candidate gets a fresh split of the random source.
nextPrime5Mod8 :: Natural -> Random.Random -> Natural
nextPrime5Mod8 n = search start
  where
    -- Smallest number >= n that is congruent to 5 (mod 8); adding 2 before
    -- dividing rounds residues 6 and 7 up to the next block of 8.
    start = 8 * ((n + 2) `div` 8) + 5
    search m gen
      | isPrime m g1 = m
      | otherwise = search (m + 8) g2
      where
        (g1, g2) = Random.split gen
-- | @randomPrime w r@ draws a random candidate with @randomOdd w@ and keeps
-- it if 'isPrime' accepts it.
--
-- Here @randomOdd w@ is presumably a random odd number of bit-width @w@
-- (confirm in "Arithmetic.Random"). Retrying is delegated to 'randomMaybe',
-- presumably redrawing until the generator yields 'Just' (confirm there).
-- NOTE(review): for very small @w@ there may be no prime to find.
randomPrime :: Natural -> Random.Random -> Natural
randomPrime w = randomMaybe candidate
  where
    candidate gen =
        let (gDraw, gTest) = Random.split gen
            p = randomOdd w gDraw
        in if isPrime p gTest then Just p else Nothing
-- | A random prime from 'randomPrime' (width @w@) that is congruent to 3
-- modulo 4.
--
-- The residue filter is applied with 'randomFilter', presumably by redrawing
-- until the predicate holds (confirm in "Arithmetic.Random").
randomPrime3Mod4 :: Natural -> Random.Random -> Natural
randomPrime3Mod4 = randomFilter congruent3Mod4 . randomPrime
  where
    congruent3Mod4 p = p `mod` 4 == 3
-- | A random prime from 'randomPrime' (width @w@) that is congruent to 5
-- modulo 8.
--
-- The residue filter is applied with 'randomFilter', presumably by redrawing
-- until the predicate holds (confirm in "Arithmetic.Random").
randomPrime5Mod8 :: Natural -> Random.Random -> Natural
randomPrime5Mod8 = randomFilter congruent5Mod8 . randomPrime
  where
    congruent5Mod8 p = p `mod` 8 == 5