{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
#if __GLASGOW_HASKELL__ < 801
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Math.NumberTheory.Moduli.DiscreteLogarithm
( discreteLogarithm
) where
import qualified Data.IntMap.Strict as M
import Data.Maybe (maybeToList)
import Data.Proxy
import Numeric.Natural (Natural)
import GHC.Integer.GMP.Internals (recipModInteger, powModInteger)
import GHC.TypeNats.Compat
import Math.NumberTheory.Moduli.Chinese (chineseRemainder2)
import Math.NumberTheory.Moduli.Class (MultMod(..), Mod, getVal)
import Math.NumberTheory.Moduli.Equations (solveLinear)
import Math.NumberTheory.Moduli.PrimitiveRoot (PrimitiveRoot(..))
import Math.NumberTheory.Moduli.Singleton
import Math.NumberTheory.Powers.Squares (integerSquareRoot)
import Math.NumberTheory.Primes (unPrime)
-- | Discrete logarithm of the third argument to the base given by the
-- primitive root, dispatching on the shape of the cyclic group:
--
-- * 'CG2': the answer is always @0@.
-- * 'CG4': @0@ when the target's value is @1@, otherwise @1@.
-- * 'CGOddPrimePower': delegated to 'discreteLogarithmPP' on the raw residues.
-- * 'CGDoubleOddPrimePower': both residues are first reduced modulo @p^k@ and
--   then handled exactly like the odd prime power case.
discreteLogarithm :: CyclicGroup Integer m -> PrimitiveRoot m -> MultMod m -> Natural
discreteLogarithm cg root target = case cg of
  CG2 -> 0
  CG4
    | getVal b == 1 -> 0
    | otherwise     -> 1
  CGOddPrimePower prime k ->
    discreteLogarithmPP (unPrime prime) k (getVal a) (getVal b)
  CGDoubleOddPrimePower prime k ->
    let p       = unPrime prime
        modulus = p ^ k
    in  discreteLogarithmPP p k (getVal a `rem` modulus) (getVal b `rem` modulus)
  where
    -- Underlying residues of the base and of the target; both are lazy, so
    -- branches that ignore them never force them.
    a = multElement (unPrimitiveRoot root)
    b = multElement target
-- | Discrete logarithm of @b@ to base @a@ modulo the prime power @p^k@.
--
-- For @k == 1@ this is just 'discreteLogarithmPrime'. Otherwise the answer is
-- assembled with 'chineseRemainder2' from two pieces:
--
-- * the logarithm of the residues modulo @p@, paired with modulus @p - 1@;
-- * the quotient @theta b / theta a@ computed modulo @p^(k-1)@, paired with
--   modulus @p^(k-1)@.
{-# INLINE discreteLogarithmPP #-}
discreteLogarithmPP :: Integer -> Word -> Integer -> Integer -> Natural
discreteLogarithmPP p k a b
  | k == 1    = discreteLogarithmPrime p a b
  | otherwise = fromInteger (chineseRemainder2 (logModP, p - 1) (logModPower, prevPower))
  where
    -- p^(k-1); only demanded in the k /= 1 branch.
    prevPower   = p ^ (k - 1)
    -- Component of the answer coming from the residues modulo p.
    logModP     = toInteger (discreteLogarithmPrime p (a `rem` p) (b `rem` p))
    thetaOf x   = theta p prevPower x
    -- Component modulo p^(k-1): theta b times the modular inverse of theta a.
    logModPower = (recipModInteger (thetaOf a) prevPower * thetaOf b) `rem` prevPower
-- | @theta p pkMinusOne a@ raises @a@ to the power @pk - pkMinusOne@ modulo
-- @pkMinusOne * pk@ (where @pk = pkMinusOne * p@), subtracts one, divides the
-- result by @pk@ with 'quot', and reduces the quotient modulo @pkMinusOne@.
--
-- The caller in this module passes @pkMinusOne = p^(k-1)@, so @pk@ is @p^k@
-- and the working modulus is @p^(2k-1)@.
{-# INLINE theta #-}
theta :: Integer -> Integer -> Integer -> Integer
theta p pkMinusOne a =
  let modulus    = pkMinusOne * p
      wideModulus = pkMinusOne * modulus
      exponent'  = modulus - pkMinusOne
      lifted     = (powModInteger a exponent' wideModulus - 1) `rem` wideModulus
  in  (lifted `quot` modulus) `rem` pkMinusOne
-- | Discrete logarithm of @b@ to base @a@ modulo @p@, choosing the algorithm
-- by the size of @p@: below @100000000@ the arguments are narrowed to 'Int'
-- and handed to 'discreteLogarithmPrimeBSGS'; otherwise
-- 'discreteLogarithmPrimePollard' is used on the 'Integer' values directly.
discreteLogarithmPrime :: Integer -> Integer -> Integer -> Natural
discreteLogarithmPrime p a b =
  if p >= bsgsLimit
    then discreteLogarithmPrimePollard p a b
    else fromIntegral (discreteLogarithmPrimeBSGS (narrow p) (narrow a) (narrow b))
  where
    -- Exclusive upper bound on moduli handled by baby-step giant-step.
    bsgsLimit :: Integer
    bsgsLimit = 100000000

    narrow :: Integer -> Int
    narrow = fromInteger
-- | Discrete logarithm of @b@ to base @a@ modulo @p@ by baby-step giant-step.
--
-- With @m = integerSquareRoot (p - 2) + 1@, the baby steps @a^j@ for
-- @0 <= j < m@ are stored in an 'M.IntMap' keyed by residue. The giant steps
-- @b * (a^(-m))^i@ for @0 <= i < m@ are then probed in order; the first hit
-- gives the answer @i * m + j@.
--
-- Calls 'error' with the inputs if no giant step hits the table, instead of
-- failing with an anonymous empty-list exception.
--
-- The caller in this module only passes @p < 100000000@, so every product in
-- '.*' stays below @10^16@.
-- NOTE(review): that fits a 64-bit 'Int' but would overflow a 32-bit one —
-- confirm the supported platforms.
discreteLogarithmPrimeBSGS :: Int -> Int -> Int -> Int
discreteLogarithmPrimeBSGS p a b =
  case [i * m + j | (v, i) <- zip giants [0 .. m - 1], j <- maybeToList (M.lookup v table)] of
    (t : _) -> t
    []      -> error ("discreteLogarithm: baby-step giant-step failed, please report this as a bug. inputs " ++ show [p, a, b])
  where
    -- Number of baby steps and of giant steps.
    m        = integerSquareRoot (p - 2) + 1
    -- 1, a, a^2, ... (mod p); only the first m entries are consumed by zip.
    babies   = iterate (.* a) 1
    -- Residue a^j  ->  exponent j.
    table    = M.fromList (zip babies [0 .. m - 1])
    aInv     = recipModInteger (toInteger a) (toInteger p)
    -- a^(-m) mod p: the factor applied by one giant step.
    bigGiant = fromInteger $ powModInteger aInv (toInteger m) (toInteger p)
    -- b, b * a^(-m), b * a^(-2m), ... (mod p).
    giants   = iterate (.* bigGiant) b
    -- Multiplication modulo p.
    x .* y   = (x * y) `rem` p
-- | Discrete logarithm of @b@ to base @a@ modulo the prime @p@ using a
-- Pollard-rho style walk with tortoise-and-hare collision detection.
--
-- Walks are started lazily from every exponent pair @(x, y)@ with
-- @0 <= x, y <= p - 1@. The first candidate @t@ that passes the check
-- @a^t == b (mod p)@ is returned. If no walk yields a verified candidate the
-- function calls 'error' with the inputs.
discreteLogarithmPrimePollard :: Integer -> Integer -> Integer -> Natural
discreteLogarithmPrimePollard p a b =
  case concatMap runPollard [(x, y) | x <- [0 .. n], y <- [0 .. n]] of
    (t : _) -> fromInteger t
    []      -> error ("discreteLogarithm: pollard's rho failed, please report this as a bug. inputs " ++ show [p,a,b])
  where
    -- Exponents are tracked modulo n = p - 1, the order of the group for
    -- prime p.
    n     = p - 1
    halfN = n `quot` 2
    -- Doubles an exponent, subtracting n once unless the input is below halfN.
    mul2 m = if m < halfN then m * 2 else m * 2 - n
    sqrtN = integerSquareRoot n

    -- One step of the walk on a triple (x, i, j). Each branch updates the
    -- exponents to match the update of x, so x == a^i * b^j (mod p) is kept.
    step (xi, !ai, !bi) = case xi `rem` 3 of
      0 -> (xi * xi `rem` p, mul2 ai, mul2 bi)
      1 -> (a * xi `rem` p, ai + 1, bi)
      _ -> (b * xi `rem` p, ai, bi + 1)

    -- Starting triple for the exponent pair (x, y). The group element must be
    -- computed modulo p (not modulo n); otherwise the invariant maintained by
    -- 'step' is already false at the start and the walk yields junk.
    initialise (x, y) = (powModInteger a x p * powModInteger b y p `rem` p, x, y)

    -- Tortoise starts one step in, hare two steps in.
    begin t = go (step t) (step (step t))

    -- Accept only genuine solutions.
    check t = powModInteger a t p == b

    -- Advance the tortoise by one step and the hare by two until their group
    -- elements coincide. On a collision, candidates come from 'solveLinear'
    -- applied to (bi - b2i) and (ai - a2i) modulo n, which is expected to
    -- solve (bi - b2i) * t + (ai - a2i) == 0 (mod n). This is only attempted
    -- when gcd (bi - b2i) n is below sqrtN (presumably to bound the number
    -- of candidates); otherwise the walk is abandoned with no candidates.
    go tort@(xi, ai, bi) hare@(x2i, a2i, b2i)
      | xi == x2i, gcd (bi - b2i) n < sqrtN = case someNatVal (fromInteger n) of
          SomeNat (Proxy :: Proxy n) -> map getVal $ solveLinear (fromInteger (bi - b2i) :: Mod n) (fromInteger (ai - a2i))
      | xi == x2i = []
      | otherwise = go (step tort) (step (step hare))

    runPollard = filter check . begin . initialise