-- |
-- Module:      Math.NumberTheory.Moduli
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Daniel Fischer
-- Stability:   Provisional
-- Portability: Non-portable (GHC extensions)
--
-- Miscellaneous functions related to modular arithmetic.
--
{-# LANGUAGE CPP, BangPatterns #-}
module Math.NumberTheory.Moduli
    ( -- * Functions with input check
      jacobi
    , invertMod
    , powerMod
    , powerModInteger
    , chineseRemainder
      -- ** Partially checked input
    , sqrtModP
      -- * Unchecked functions
    , jacobi'
    , powerMod'
    , powerModInteger'
    , sqrtModPList
    , sqrtModP'
    , tonelliShanks
    , sqrtModPP
    , sqrtModPPList
    , sqrtModF
    , sqrtModFList
    , chineseRemainder2
    ) where

#include "MachDeps.h"

import Data.Word
import Data.Bits
import Data.Array.Unboxed
import Data.Array.Base (unsafeAt)
import Data.Maybe (fromJust)
import Data.List (nub)
import Control.Monad (foldM, liftM2)

import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
import Math.NumberTheory.GCD (extendedGCD)
import Math.NumberTheory.Primes.Heap (sieveFrom)

-- Guesstimated startup time for the Heap algorithm is lower than
-- the cost to sieve an entire chunk.
-- (NOTE(review): presumably this is why 'sieveFrom' is taken from the
-- Heap module for 'findNonSquare' - confirm.)

-- | Invert a number relative to a modulus.
-- If @number@ and @modulus@ are coprime, the result is
-- @Just inverse@ where
--
-- > (number * inverse) `mod` (abs modulus) == 1
-- > 0 <= inverse < abs modulus
--
-- unless @modulus == 0@ and @abs number == 1@, in which case the
-- result is @Just number@.
-- If @gcd number modulus > 1@, the result is @Nothing@.
-- For @abs modulus == 1@ the result is also @Nothing@: no residue
-- modulo 1 equals 1, so the final verification step always rejects.
invertMod :: Integer -> Integer -> Maybe Integer
invertMod k 0 = if k == 1 || k == (-1) then Just k else Nothing
invertMod k m = wrap $ go False 1 0 m' k'
  where
    m' = abs m
    -- k reduced into the range [0, m')
    k' | r < 0     = r+m'
       | otherwise = r
         where
           r = k `rem` m'
    -- Verify the candidate; a gcd greater than 1 shows up here as a
    -- product that is not congruent to 1.
    wrap x = case (x*k') `rem` m' of
               1 -> Just x
               _ -> Nothing
    -- Calculate modular inverse of k' modulo m' by continued fraction expansion
    -- of m'/k', say [a_0,a_1,...,a_s]. Let the convergents be p_j/q_j.
    -- Starting from j = -2, the arguments of go are
    -- (p_j/q_j) > m'/k', p_{j+1}, p_j, and n, d with n/d = [a_{j+2},...,a_s].
    -- Since m'/k' = p_s/q_s, and p_j*q_{j+1} - p_{j+1}*q_j = (-1)^(j+1), we have
    -- p_{s-1}*k' - q_{s-1}*m' = (-1)^s * gcd m' k', so if the inverse exists,
    -- it is either p_{s-1} or -p_{s-1}, depending on whether s is even or odd.
    go !b _ po _ 0 = if b then po else (m'-po)
    go b !pn po n d = case n `quotRem` d of
                        (q,r) -> go (not b) (q*pn+po) pn d r

-- | Jacobi symbol of two numbers.
-- The \"denominator\" must be odd and positive, this condition is checked.
--
-- If both numbers have a common prime factor, the result
-- is @0@, otherwise it is ±1.
{-# SPECIALISE jacobi :: Integer -> Integer -> Int,
                         Int -> Int -> Int,
                         Word -> Word -> Int
  #-}
jacobi :: (Integral a, Bits a) => a -> a -> Int
jacobi a b
  | b < 0       = error "Math.NumberTheory.Moduli.jacobi: negative denominator"
  | evenI b     = error "Math.NumberTheory.Moduli.jacobi: even denominator"
  | b == 1      = 1
  | a == 0      = 0
  | a == 1      = 1
  | otherwise   = jacobi' a b   -- b odd, > 1, a neither 0 or 1

-- Invariant: b > 1 and odd
-- | Jacobi symbol of two numbers without validity check of
-- the \"denominator\".
{-# SPECIALISE jacobi' :: Integer -> Integer -> Int,
                          Int -> Int -> Int,
                          Word -> Word -> Int
  #-}
jacobi' :: (Integral a, Bits a) => a -> a -> Int
jacobi' a b
  | a == 0      = 0
  | a == 1      = 1
    -- Negative numerator: (a/b) = (-1/b) * (2/b)^z * (o/b) where |a| = 2^z * o.
    -- n is (-1/b), i.e. +1 iff b == 1 (mod 4); (2/b) is looked up in 'jac2'.
  | a < 0       = let n | rem4 b == 1 = 1
                        | otherwise   = -1
                      -- Blech, minBound may pose problems
                      -- (hence the detour via Integer before taking abs)
                      (z,o) = shiftToOddCount (abs $ toInteger a)
                      s | evenI z || unsafeAt jac2 (rem8 b) == 1 = n
                        | otherwise                              = (-n)
                  in s*jacobi' (fromInteger o) b
    -- The symbol only depends on a modulo b; a remainder of 0 means a
    -- common factor.
  | a >= b      = case a `rem` b of
                    0 -> 0
                    r -> jacPS 1 r b
    -- Even numerator 0 < a < b: strip 2^z (factor (2/b)^z), then apply
    -- reciprocity to the odd part o; r is -1 iff o and b are both 3 (mod 4).
  | evenI a     = case shiftToOddCount a of
                    (z,o) -> let r = 2 - (rem4 o .&. rem4 b)
                                 s | evenI z || unsafeAt jac2 (rem8 b) == 1 = r
                                   | otherwise                              = (-r)
                             in jacOL s b o
    -- Odd numerator 1 < a < b: reciprocity flips the symbol, with a sign
    -- change iff both are 3 (mod 4).
  | otherwise   = case rem4 a .&. rem4 b of
                    3 -> jacOL (-1) b a
                    _ -> jacOL 1 b a

-- numerator positive and smaller than denominator
-- The Int argument j is the sign accumulated so far; the result is
-- j * (a/b).
{-# SPECIALISE jacPS :: Int -> Integer -> Integer -> Int,
                        Int -> Int -> Int -> Int,
                        Int -> Word -> Word -> Int
  #-}
jacPS :: (Integral a, Bits a) => Int -> a -> a -> Int
jacPS !j a b
  | evenI a     = case shiftToOddCount a of
                    -- (2/b)^z == 1: only the reciprocity sign matters
                    (z,o) | evenI z || unsafeAt jac2 (rem8 b) == 1 ->
                              jacOL (if rem4 o .&. rem4 b == 3 then (-j) else j) b o
                    -- (2/b)^z == -1: combine with the reciprocity sign
                          | otherwise ->
                              jacOL (if rem4 o .&. rem4 b == 3 then j else (-j)) b o
  | otherwise   = jacOL (if rem4 a .&. rem4 b == 3 then (-j) else j) b a

-- numerator odd, positive and larger than denominator
-- Returns j * (a/b); a denominator of 1 contributes the symbol 1.
{-# SPECIALISE jacOL :: Int -> Integer -> Integer -> Int,
                        Int -> Int -> Int -> Int,
                        Int -> Word -> Word -> Int
  #-}
jacOL :: (Integral a, Bits a) => Int -> a -> a -> Int
jacOL !j a b
  | b == 1      = j
  | otherwise   = case a `rem` b of
                    0 -> 0
                    r -> jacPS j r b

-- | Modular power.
--
-- > powerMod base exponent modulus
--
-- calculates @(base ^ exponent) \`mod\` modulus@ by repeated squaring and reduction.
-- If @exponent < 0@ and @base@ is invertible modulo @modulus@, @(inverse ^ |exponent|) \`mod\` modulus@
-- is calculated. This function does some input checking and sanitation before calling the unsafe worker.
-- If @exponent < 0@ and @base@ is not invertible, an error is raised.
{-# RULES
"powerMod/Integer"  powerMod = powerModInteger
  #-}
{-# INLINE [1] powerMod #-}
powerMod :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod = powerModImpl

-- Checked implementation behind 'powerMod' (the rewrite rule above
-- replaces 'powerMod' by 'powerModInteger' for Integer exponents).
{-# SPECIALISE powerModImpl :: Integer -> Int -> Integer -> Integer,
                               Integer -> Word -> Integer -> Integer
  #-}
powerModImpl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerModImpl base expo md
  | md == 0     = base ^ expo
  | md' == 1    = 0
  | expo == 0   = 1
  | bse' == 1   = 1
  | expo < 0    = case invertMod bse' md' of
                    -- Negate in Integer: for a bounded exponent type
                    -- @negate minBound == minBound@ is still negative, and
                    -- the shift loop in 'powerMod'Impl' would never reach 1.
                    Just i  -> powerModInteger' i (negate (toInteger expo)) md'
                    Nothing -> error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus"
  | bse' == 0   = 0
  | otherwise   = powerMod'Impl bse' expo md'
    where
      -- only the absolute value of the modulus matters
      md' = abs md
      -- canonical base in [0, md'); skip the division when already in range
      bse' = if base < 0 || md' <= base then base `mod` md' else base

-- | Modular power worker without input checking.
-- Assumes all arguments strictly positive and modulus greater than 1.
{-# RULES
"powerMod'/Integer"  powerMod' = powerModInteger'
  #-}
{-# INLINE [1] powerMod' #-}
powerMod' :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod' = powerMod'Impl

-- Right-to-left binary exponentiation: s runs through base^(2^i) and a
-- accumulates the product of those powers whose exponent bit is set.
-- The recursion only stops when the remaining exponent is exactly 1, so
-- it does not terminate for expo <= 0 (hence the precondition).
{-# SPECIALISE powerMod'Impl :: Integer -> Int -> Integer -> Integer,
                                Integer -> Word -> Integer -> Integer
  #-}
powerMod'Impl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod'Impl base expo md = go expo 1 base
  where
    go 1 !a !s  = (a*s) `rem` md
    go e a s
      | testBit e 0 = go (e `shiftR` 1) ((a*s) `rem` md) ((s*s) `rem` md)
      | otherwise   = go (e `shiftR` 1) a ((s*s) `rem` md)

-- | Specialised version of 'powerMod' for 'Integer' exponents.
-- Reduces the number of shifts of the exponent since shifting
-- large 'Integer's is expensive. Call this function directly
-- if you don't want or can't rely on rewrite rules.
powerModInteger :: Integer -> Integer -> Integer -> Integer
powerModInteger base ex mdl
    -- modulus 0: plain power; (^) raises an error for a negative exponent
  | mdl == 0    = base ^ ex
  | mdl' == 1   = 0
  | ex == 0     = 1
    -- a negative exponent means powering the modular inverse
  | ex < 0      = case invertMod bse' mdl' of
                    Just i  -> powerModInteger' i (negate ex) mdl'
                    Nothing -> error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus"
  | bse' == 0   = 0
  | bse' == 1   = 1
  | otherwise   = powerModInteger' bse' ex mdl'
    where
      -- only the absolute value of the modulus matters
      mdl' = abs mdl
      -- canonical base in [0, mdl'); skip the division when already in range
      bse' = if base < 0 || mdl' <= base then base `mod` mdl' else base

-- | Specialised worker without input checks. Makes the same assumptions
-- as the general version 'powerMod''.
powerModInteger' :: Integer -> Integer -> Integer -> Integer
powerModInteger' base expo md = go w1 1 base e1
  where
    -- The exponent is consumed in 64-bit chunks, least significant first:
    -- w1 is the lowest chunk (fromInteger truncates), e1 the remaining bits.
    -- Every chunk except the last is processed in full by go; the last
    -- (most significant) one goes to end, which stops when the remaining
    -- word is exactly 1 - so expo must be positive.
    w1 = fromInteger expo
    e1 = expo `shiftR` 64
#if WORD_SIZE_IN_BITS == 32
    -- Shifting large Integers is expensive, hence we reduce the
    -- number of shifts by processing in 64-bit chunks.
    -- On 32-bit systems, every testBit on a Word64 would be a C-call,
    -- thus it is faster to split each Word64 into the constituent 32-bit
    -- Words and process those separately.
    -- The code becomes ugly, unfortunately.
    go :: Word64 -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0 = end a s w
    go w a s e = inner1 a s 0
      where
        -- low and high 32-bit halves of the current chunk
        wl :: Word
        !wl = fromIntegral w
        wh :: Word
        !wh = fromIntegral (w `shiftR` 32)
        inner1 !au !sq 32 = inner2 au sq 0
        inner1 au sq i
          | testBit wl i = inner1 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner1 au ((sq*sq) `rem` md) (i+1)
        -- after both halves, continue with the next 64-bit chunk
        inner2 !au !sq 32 = go (fromInteger e) au sq (e `shiftR` 64)
        inner2 au sq i
          | testBit wh i = inner2 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner2 au ((sq*sq) `rem` md) (i+1)
    -- last chunk: if the high half is empty, only the low half remains;
    -- otherwise run through all 32 low bits and finish on the high half
    end !a !s w
      | wh == 0   = fin a s wl
      | otherwise = innerE a s 0
        where
          wl :: Word
          !wl = fromIntegral w
          wh :: Word
          !wh = fromIntegral (w `shiftR` 32)
          innerE !au !sq 32 = fin au sq wh
          innerE au sq i
            | testBit wl i = innerE ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
            | otherwise    = innerE au ((sq*sq) `rem` md) (i+1)
    -- consume bits of a nonzero word up to and including its top set bit
    fin :: Integer -> Integer -> Word -> Integer
    fin !a !s 1 = (a*s) `rem` md
    fin a s w
      | testBit w 0 = fin ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = fin a ((s*s) `rem` md) (w `shiftR` 1)
#else
    -- WORD_SIZE_IN_BITS == 64, otherwise things wouldn't compile anyway
    -- Shorter code since we need not split each 64-bit word.
    go :: Word -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0 = end a s w
    go w a s e = inner a s 0
      where
        -- after all 64 bits, continue with the next chunk
        inner !au !sq 64 = go (fromInteger e) au sq (e `shiftR` 64)
        inner au sq i
          | testBit w i = inner ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise   = inner au ((sq*sq) `rem` md) (i+1)
    -- consume bits of the last (nonzero) chunk up to its top set bit
    end !a !s 1 = (a*s) `rem` md
    end a s w
      | testBit w 0 = end ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = end a ((s*s) `rem` md) (w `shiftR` 1)
#endif

-- | @sqrtModP n prime@ calculates a modular square root of @n@ modulo @prime@
-- if that exists. The second argument /must/ be a (positive) prime, otherwise
-- the computation may not terminate and if it does, may yield a wrong result.
-- The precondition is /not/ checked.
--
-- If @prime@ is a prime and @n@ a quadratic residue modulo @prime@, the result
-- is @Just r@ where @r^2 ≡ n (mod prime)@, if @n@ is a quadratic nonresidue,
-- the result is @Nothing@.
sqrtModP :: Integer -> Integer -> Maybe Integer
sqrtModP n 2 = Just (n `mod` 2)
sqrtModP n prime = case jacobi' n prime of
                     -- n is divisible by prime
                     0 -> Just 0
                     1 -> Just (sqrtModP' (n `mod` prime) prime)
                     _ -> Nothing

-- | @sqrtModPList n prime@ computes the list of all square roots of @n@
-- modulo @prime@. @prime@ /must/ be a (positive) prime.
-- The precondition is /not/ checked.
sqrtModPList :: Integer -> Integer -> [Integer]
sqrtModPList n prime
  | prime == 2  = [n `mod` 2]
  | otherwise   = case sqrtModP n prime of
                    Just 0 -> [0]
                    Just r -> [r,prime-r] -- The group of units in Z/(p) is cyclic
                    _      -> []

-- | @sqrtModP' square prime@ finds a square root of @square@ modulo
-- @prime@. @prime@ /must/ be a (positive) prime, and @square@ /must/ be a
-- quadratic residue modulo @prime@, i.e. @'jacobi' square prime == 1@.
-- The precondition is /not/ checked.
--
-- @square@ is first reduced modulo @prime@, so the result lies in
-- @[0, prime)@ even when @square@ is negative or not smaller than @prime@.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square prime
  | prime == 2      = sq
    -- p == 3 (mod 4): sq^((p+1)/4) is a root by Euler's criterion
  | rem4 prime == 3 = powerModInteger' sq ((prime + 1) `quot` 4) prime
  | otherwise       = tonelliShanks sq prime
    where
      -- The workers reduce with 'rem', which keeps the sign of a negative
      -- argument; reducing here yields the canonical non-negative root.
      sq = square `mod` prime

-- | @tonelliShanks square prime@ calculates a square root of @square@
-- modulo @prime@, where @prime@ is a prime of the form @4*k + 1@ and
-- @square@ is a quadratic residue modulo @prime@, using the
-- Tonelli-Shanks algorithm.
-- No checks on the input are performed.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = loop rc t1 generator log2
  where
    -- Work with the canonical residue: everything below reduces with 'rem'
    -- and the loop tests compare against the literal 1, which a negative
    -- representative such as @1 - prime@ would never match.
    sq = square `mod` prime
    -- prime - 1 = 2^log2 * q with q odd
    (log2,q) = shiftToOddCount (prime-1)
    nonSquare = findNonSquare prime
    -- generator of the 2-Sylow subgroup of the unit group
    generator = powerModInteger' nonSquare q prime
    -- candidate root and the error term t = rc^2 / sq
    rc = powerModInteger' sq ((q+1) `quot` 2) prime
    t1 = powerModInteger' sq q prime
    msqr x = (x*x) `rem` prime
    -- k-fold repeated squaring
    msquare 0 x = x
    msquare k x = msquare (k-1) (msqr x)
    -- least i with x^(2^i) == 1
    findPeriod per 1 = per
    findPeriod per x = findPeriod (per+1) (msqr x)
    loop !r t c m
      | t == 1    = r
      | otherwise = loop nextR nextT nextC nextM
        where
          nextM = findPeriod 0 t
          b = msquare (m - 1 - nextM) c
          nextR = (r*b) `rem` prime
          nextC = msqr b
          nextT = (t*nextC) `rem` prime

-- | @sqrtModPP n (prime,expo)@ calculates a square root of @n@
-- modulo @prime^expo@ if one exists. @prime@ /must/ be a
-- (positive) prime. @expo@ must be positive, @n@ must be coprime
-- to @prime@
sqrtModPP :: Integer -> (Integer,Int) -> Maybe Integer
sqrtModPP n (2,e) = sqM2P n e
sqrtModPP n (prime,expo) = case sqrtModP n prime of
                             Just r -> Just $ fixup r
                             _      -> Nothing
  where
    -- r is a root modulo prime; lift it if the actual prime power
    -- dividing r^2 - n is below prime^expo
    fixup r = let diff' = r*r-n
              in if diff' == 0
                   then r
                   else case splitOff prime diff' of
                          (e,q) | expo <= e -> r
                                | otherwise -> hoist (fromJust $ invertMod (2*r) prime) r (q `mod` prime) (prime^e)
    -- Hensel step: root^2 - n = pp * elim (pp = prime^e, elim reduced mod prime).
    -- Adding inv*(-elim)*pp, with inv the inverse of 2*root mod prime, gives a
    -- root modulo prime*pp; repeat until the exponent reaches expo.
    hoist inv root elim pp
      | diff' == 0    = root'
      | expo <= ex    = root'
      | otherwise     = hoist inv root' (nelim `mod` prime) (prime^ex)
        where
          root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)
          diff' = root'*root' - n
          (ex, nelim) = splitOff prime diff'

-- Square root of n modulo 2^e, if one exists (for e < 2 simply n mod 2).
-- Writing n mod 2^e = 2^k * s with s odd, there is no root when k is odd;
-- otherwise the root is 2^(k/2) times a root of s modulo 2^(e-k), found by
-- 'solve'. Since n mod 2^e is nonzero and below 2^e here, k < e always.
sqM2P :: Integer -> Int -> Maybe Integer
sqM2P n e
  | e < 2       = Just (n `mod` 2)
  | n' == 0     = Just 0
  | odd k       = Nothing
  | otherwise   = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
    where
      mdl = 1 `shiftL` e
      n' = n `mod` mdl
      (k,s) = shiftToOddCount n'
      k2 = k `quot` 2
      e2 = e-k
      -- root of the odd number r modulo 2^p
      solve _ 1 = Just 1
      solve 1 _ = Just 1
      solve r p
        | rem4 r == 3   = Nothing   -- otherwise r ≡ 1 (mod 4)
        | p == 2        = Just 1    -- otherwise p >= 3
        | rem8 r == 5   = Nothing   -- otherwise r ≡ 1 (mod 8)
        | otherwise     = fixup r (fst $ shiftToOddCount (r-1))
          where
            -- invariant: x^2 - r is divisible by exactly 2^pw; adding
            -- 2^(pw-1) to x raises that power, until it reaches e2
            fixup x pw
              | pw >= e2    = Just x
              | otherwise   = fixup x' pw'
                where
                  x' = x + (1 `shiftL` (pw-1))
                  d = x'*x' - r
                  pw' = if d == 0 then e2 else fst (shiftToOddCount d)

-- | @sqrtModF n primePowers@ calculates a square root of @n@ modulo
-- @product [p^k | (p,k) <- primePowers]@ if one exists and all primes
-- are distinct.
sqrtModF :: Integer -> [(Integer,Int)] -> Maybe Integer
sqrtModF n pps = do
    roots <- mapM (sqrtModPP n) pps
    chineseRemainder $ zip roots (map (uncurry (^)) pps)

-- | @sqrtModFList n primePowers@ calculates all square roots of @n@ modulo
-- @product [p^k | (p,k) <- primePowers]@ if all primes are distinct.
-- For an empty list the modulus is 1 and the only root is 0.
sqrtModFList :: Integer -> [(Integer,Int)] -> [Integer]
sqrtModFList _ [] = [0]
sqrtModFList n pps = map fst $ foldl1 (liftM2 comb) cs
  where
    ms :: [Integer]
    ms = map (uncurry (^)) pps
    -- all roots modulo each prime power
    rs :: [[Integer]]
    rs = map (sqrtModPPList n) pps
    -- tag every root with its modulus
    cs :: [[(Integer,Integer)]]
    cs = zipWith (\l m -> map (\x -> (x,m)) l) rs ms
    -- liftM2 in the list monad combines every choice of roots via CRT
    comb t1@(_,m1) t2@(_,m2) = (chineseRemainder2 t1 t2,m1*m2)

-- | @sqrtModPPList n (prime,expo)@ calculates the list of all
-- square roots of @n@ modulo @prime^expo@. The same restriction
-- as in 'sqrtModPP' applies to the arguments.
sqrtModPPList :: Integer -> (Integer,Int) -> [Integer]
sqrtModPPList n (2,expo)
    -- Modulo 2 the root is unique: the family r, r + 2^(expo-1), -r,
    -- 2^(expo-1) - r used below only consists of roots when expo >= 2.
  | expo == 1 = [n `mod` 2]
  | otherwise = case sqM2P n expo of
                  Just r -> let m = 1 `shiftL` (expo-1)
                            in nub [r, (r+m) `mod` (2*m), (m-r) `mod` (2*m), (2*m-r) `mod` (2*m)]
                  _      -> []
sqrtModPPList n pe@(prime,expo) = case sqrtModPP n pe of
                                    Just 0 -> [0]
                                    Just r -> [prime^expo - r, r] -- The group of units in Z/(p^e) is cyclic
                                    _      -> []

-- | Given a list @[(r_1,m_1), ..., (r_n,m_n)]@ of @(residue,modulus)@
-- pairs, @chineseRemainder@ calculates the solution to the simultaneous
-- congruences
--
-- >
-- > r ≡ r_k (mod m_k)
-- >
--
-- if all moduli are pairwise coprime. If not all moduli are
-- pairwise coprime, the result is @Nothing@ regardless of whether
-- a solution exists.
--
-- Moduli @1@ and @-1@ impose no condition and are accepted.
-- NOTE(review): a zero modulus is not supported; it can lead to a
-- division by zero in the @quot@ below.
chineseRemainder :: [(Integer,Integer)] -> Maybe Integer
chineseRemainder remainders = foldM addRem 0 remainders
  where
    !modulus = product (map snd remainders)
    addRem acc (r,m)
        -- Every integer satisfies a congruence modulo +-1, but 'invertMod'
        -- returns Nothing for such a modulus, so such a pair is skipped.
      | abs m == 1 = Just acc
      | otherwise  = do
          -- cf vanishes modulo every other modulus; inv*cf is 1 modulo m
          let cf = modulus `quot` m
          inv <- invertMod cf m
          Just $! (acc + inv*cf*r) `mod` modulus

-- | @chineseRemainder2 (r_1,m_1) (r_2,m_2)@ calculates the solution of
--
-- >
-- > r ≡ r_k (mod m_k)
--
-- if @m_1@ and @m_2@ are coprime.
chineseRemainder2 :: (Integer,Integer) -> (Integer,Integer) -> Integer
chineseRemainder2 (r1, md1) (r2,md2)
    = case extendedGCD md1 md2 of
        -- with u*md1 + v*md2 == 1, (1 - u*md1) is 1 mod md1 and 0 mod md2
        (_,u,v) -> ((1 - u*md1)*r1 + (1 - v*md2)*r2) `mod` (md1*md2)

-- Utilities

-- For large Integers, going via Int is much faster than bit-fiddling
-- on the Integer, so we do that.
{-# SPECIALISE evenI :: Integer -> Bool,
                        Int -> Bool,
                        Word -> Bool
  #-}
evenI :: (Integral a, Bits a) => a -> Bool
evenI n = fromIntegral n .&. 1 == (0 :: Int)

-- Residue modulo 4 from the low bits (also correct for negative numbers,
-- since the Int conversion keeps the two's complement low bits).
{-# SPECIALISE rem4 :: Integer -> Int,
                       Int -> Int,
                       Word -> Int
  #-}
rem4 :: (Integral a, Bits a) => a -> Int
rem4 n = fromIntegral n .&. 3

-- Residue modulo 8 from the low bits.
{-# SPECIALISE rem8 :: Integer -> Int,
                       Int -> Int,
                       Word -> Int
  #-}
rem8 :: (Integral a, Bits a) => a -> Int
rem8 n = fromIntegral n .&. 7

-- The Jacobi symbol (2/b) indexed by b mod 8; only odd indices are
-- meaningful, the even entries are 0.
jac2 :: UArray Int Int
jac2 = array (0,7) [(0,0),(1,1),(2,0),(3,-1),(4,0),(5,-1),(6,0),(7,1)]

-- A quadratic non-residue modulo the odd prime n. For n ≡ 3, 5 (mod 8)
-- this is 2; otherwise primes are tried in increasing order until
-- @jacobi' p n == -1@.
findNonSquare :: Integer -> Integer
findNonSquare n
  | rem8 n == 5 || rem8 n == 3  = 2
  | otherwise                   = search primelist
    where
      primelist = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
                    ++ sieveFrom (68 + n `rem` 4) -- prevent sharing
      search (p:ps)
        | jacobi' p n == -1 = p
        | otherwise         = search ps
      search _ = error "Should never have happened, prime list exhausted."