{-# LANGUAGE NoImplicitPrelude, NoMonomorphismRestriction #-}
module Algebra.Algorithms.PrimeTest
       (repeatedSquare, modPow, fermatTest, isPseudoPrime
       ) where
import           Algebra.Prelude.Core     hiding (div, mod)
import           Control.Lens             ((&), (+~), _1)
import           Control.Monad.Random     (MonadRandom, uniform)
import           Data.List                (findIndex)
import           Numeric.Decidable.Zero   (isZero)
import           Numeric.Domain.Euclidean ()
import           Prelude                  (div, mod)
import qualified Prelude                  as P

-- | Verdict produced by 'fermatTest'.
--
--   The derived 'Ord' instance orders the constructors as
--   @'Composite' < 'ProbablyPrime' < 'Prime'@.
--
--   NOTE(review): this type does not appear in the module export list even
--   though the exported 'fermatTest' returns it; consider exporting it.
data PrimeResult
  = Composite
  | ProbablyPrime
  | Prime
  deriving (Eq, Ord, Read, Show)

-- | Calculates the @n@-th power of @a@ efficiently, using the repeated
--   square (left-to-right binary exponentiation) method.
--
--   Only a 'Multiplicative' constraint is required, so there is no identity
--   element to return for @n = 0@: the exponent must be positive, and
--   @n = 0@ raises an explicit error.
repeatedSquare :: Multiplicative r => r -> Natural -> r
repeatedSquare a n =
  case binRep n of
    []         -> P.error "repeatedSquare: exponent must be positive"
    -- The most significant bit of a positive exponent is always 1; it is
    -- accounted for by starting the accumulator at @a@.
    (_ : bits) -> go a bits
  where
    -- For each remaining bit (most significant first): square the
    -- accumulator, and multiply in one more @a@ when the bit is set.
    go b []        = b
    go b (nk : ns) = go (if nk == 1 then b * b * a else b * b) ns


-- | Binary digits of a natural number, most significant bit first.
--   @binRep 0@ is the empty list; for a positive argument the first digit
--   is always @1@.
binRep :: Natural -> [Natural]
binRep = digitsOnto []
  where
    -- Peel off the least significant bit and push it onto the front of the
    -- digits collected so far, until the number is exhausted.
    digitsOnto acc 0 = acc
    digitsOnto acc k = digitsOnto (k `mod` 2 : acc) (k `div` 2)

-- | Fermat test for pseudo-primeness.
--
--   Picks a random base @a@ from @[2 .. n - 2]@ and checks Fermat's little
--   theorem, @a ^ (n - 1) == 1 (mod n)@.  Returns 'Composite' if the
--   congruence fails and 'ProbablyPrime' if it holds.
--
--   Small inputs are answered directly: @n < 2@ is 'Composite' (not prime),
--   and @n = 2@ and @n = 3@ are 'Prime'.  For these values the base range
--   would be empty.
fermatTest :: MonadRandom m => Integer -> m PrimeResult
fermatTest n
  | n < 2            = return Composite
  | n == 2 || n == 3 = return Prime
  | otherwise        = do
      a <- uniform [2 .. n - 2]
      -- modPow takes (base, modulus, exponent): this is a^(n-1) mod n.
      let b = modPow a n (n - 1)
      return $ if b /= 1 then Composite else ProbablyPrime

-- | @'modPow' x m p@ efficiently calculates @x ^ p `'rem'` m@.
--
--   It uses right-to-left binary exponentiation and reduces modulo @m@ after
--   every multiplication, so intermediate values stay small.  The result is
--   reduced modulo @m@ even when @p = 0@ (so @modPow x 1 0@ is zero).
--   A negative exponent is rejected with an error.
modPow :: (P.Integral a, Euclidean r) => r -> r -> a -> r
modPow x m p
  | p < 0     = P.error "modPow: negative exponent"
  | otherwise = go x (one `rem` m) p
  where
    -- Invariant: the final answer is acc * base ^ e, reduced modulo m.
    go _    acc 0 = acc
    go base acc e =
      let acc' = if e `mod` 2 == 1 then (acc * base) `rem` m else acc
      in go ((base * base) `rem` m) acc' (e `div` 2)

-- | @'splitFactor' d n@ strips every factor @d@ out of @n@.
--
--   It returns @(k, m)@ with @n == d ^ k * m@, where @d@ does not divide @m@.
--   For example, @splitFactor 2 12 == (2, 3)@.  A zero @n@ is returned as
--   @(0, n)@, because zero is divisible by every @d@ and the loop would never
--   stop.
--
--   NOTE(review): if @d@ is a unit (e.g. @1@), every @n@ is divisible by it
--   and this does not terminate, so callers must pass a non-unit @d@.
splitFactor :: Euclidean r => r -> r -> (Int, r)
splitFactor d n
  | isZero n  = (0, n)
  | isZero r  = splitFactor d q & _1 +~ 1
  | otherwise = (0, n)
  where
    -- n == d * q + r; d divides n exactly when r is zero.
    (q, r) = n `divide` d

-- | @'isPseudoPrime' n@ tests if the given integer @n@ is pseudo prime,
--   using one Miller-Rabin round with a randomly chosen base.
--   It returns @'Left' p@ if @p < n@ divides @n@,
--   @'Right' 'True'@ if @n@ is pseudo-prime,
--   @'Right' 'False'@ if it is not pseudo-prime but no clue can be found.
--   Inputs below @2@ are reported as @'Right' 'False'@.
isPseudoPrime :: MonadRandom m
              => Integer -> m (Either Integer Bool)
isPseudoPrime n
  | n < 2            = return $ Right False
  | n == 2 || n == 3 = return $ Right True
  | otherwise        = do
      a <- uniform [2..n P.- 2]
      let d = P.gcd a n
      return $ if d > 1 then Left d else millerRabin a
  where
    millerRabin a =
      let -- n - 1 == 2^v * m
          (v, m) = splitFactor 2 (n P.- 1)
          -- modPow takes (base, modulus, exponent): b0 = a^m mod n.
          b0 = modPow a n m
          -- bs !! i == a^(m * 2^i) mod n, for i = 0 .. v
          bs = take (v P.+ 1) $ iterate (\b -> (b * b) `mod` n) b0
      in if b0 == 1
         then Right True
         else case findIndex (== 1) bs of
           -- a^(n-1) is not 1 mod n, so n is composite, but no factor is known.
           Nothing -> Right False
           Just j ->
             -- j >= 1 because b0 /= 1, so x is the element squared into 1:
             -- a square root of 1 modulo n.  If x is not -1 mod n, then
             -- gcd (x - 1) n is a non-trivial factor of n.
             let x = bs !! (j P.- 1)
                 g = P.gcd (x P.- 1) n
             in if g == 1 || g == n then Right True else Left g