-- |
-- Module:      Math.NumberTheory.Moduli.Sqrt
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko
-- Stability:   Provisional
-- Portability: Non-portable (GHC extensions)
--
-- Modular square roots.
--

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}

module Math.NumberTheory.Moduli.Sqrt
  ( -- * New interface
    sqrtsMod
  , sqrtsModFactorisation
  , sqrtsModPrimePower
  , sqrtsModPrime
    -- * Old interface
  , Old.sqrtModP
  , Old.sqrtModPList
  , Old.sqrtModP'
  , Old.tonelliShanks
  , Old.sqrtModPP
  , Old.sqrtModPPList
  , Old.sqrtModF
  , Old.sqrtModFList
  ) where

import Control.Arrow hiding (loop)
import Control.Monad (liftM2)
import Data.Bits

import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Class (Mod, getVal, getMod, KnownNat)
import Math.NumberTheory.Moduli.Jacobi
import Math.NumberTheory.Powers.Modular (powMod)
import qualified Math.NumberTheory.Primes.Factorisation as F (factorise)
import Math.NumberTheory.Primes.Types
import Math.NumberTheory.Primes.Sieve (sieveFrom)
import Math.NumberTheory.Utils (shiftToOddCount, splitOff, recipMod)
import Math.NumberTheory.Utils.FromIntegral

import qualified Math.NumberTheory.Moduli.SqrtOld as Old

-- | List all modular square roots.
--
-- The modulus is factorised with 'F.factorise' and the actual work is
-- delegated to 'sqrtsModFactorisation'; the resulting integers are
-- converted back into residues of the same modulus.
--
-- >>> :set -XDataKinds
-- >>> sqrtsMod (1 :: Mod 60)
-- [(1 `modulo` 60),(49 `modulo` 60),(41 `modulo` 60),(29 `modulo` 60),(31 `modulo` 60),(19 `modulo` 60),(11 `modulo` 60),(59 `modulo` 60)]
sqrtsMod :: KnownNat m => Mod m -> [Mod m]
sqrtsMod a = map fromInteger roots
  where
    -- Roots as plain integers, computed from the factorised modulus.
    roots = sqrtsModFactorisation (getVal a) factors
    -- Factorisation of the modulus in the shape sqrtsModFactorisation expects.
    factors = map toPrimePower (F.factorise (getMod a))
    -- Wrap the prime and convert its multiplicity to a Word.
    toPrimePower = PrimeNat . integerToNatural *** intToWord

-- | List all square roots modulo a number whose factorisation is passed as the second argument.
--
-- Roots modulo each prime power come from 'sqrtsModPrimePower' and are
-- glued together pairwise with 'chineseRemainder2'. If any prime power
-- admits no root, the result is empty. An empty factorisation stands for
-- the modulus 1, whose only residue is 0.
--
-- >>> sqrtsModFactorisation 1 (factorise 60)
-- [1,49,41,29,31,19,11,59]
sqrtsModFactorisation :: Integer -> [(Prime Integer, Word)] -> [Integer]
sqrtsModFactorisation _ [] = [0]
sqrtsModFactorisation n pps =
  [ root | (root, _) <- foldl1 (liftM2 glue) candidates ]
  where
    -- One list per prime power: every root paired with its modulus p^pow.
    candidates :: [[(Integer, Integer)]]
    candidates = map rootsWithModulus pps

    rootsWithModulus :: (Prime Integer, Word) -> [(Integer, Integer)]
    rootsWithModulus (p@(PrimeNat q), pow) =
      let modulus = toInteger q ^ pow
      in  [ (x, modulus) | x <- sqrtsModPrimePower n p pow ]

    -- Merge two (residue, modulus) pairs into one modulo the product.
    glue lhs@(_, m1) rhs@(_, m2) = (chineseRemainder2 lhs rhs, m1 * m2)

-- | List all square roots modulo a power of a prime.
--
-- The argument is first reduced modulo @p^expo@ and split as
-- @p^v * u@ with @u@ coprime to @p@. A zero residue yields every multiple
-- of @p^ceil(expo/2)@; an odd @v@ yields no roots; otherwise a root of
-- @u@ (via 'sqM2P' for @p = 2@, 'sqrtModPP'' for odd @p@) is scaled by
-- @p^(v/2)@ and expanded into all roots.
--
-- >>> sqrtsModPrimePower 7 (fromJust (isPrime 3)) 2
-- [4,5]
-- >>> sqrtsModPrimePower 9 (fromJust (isPrime 3)) 3
-- [3,12,21,24,6,15]
sqrtsModPrimePower :: Integer -> Prime Integer -> Word -> [Integer]
sqrtsModPrimePower nn p 1 = sqrtsModPrime nn p
sqrtsModPrimePower nn (PrimeNat (toInteger -> prime)) expo =
  case splitOff prime (nn `mod` modulus) of
    -- Zero residue: x^2 vanishes modulo p^expo iff p^ceil(expo/2) divides x.
    (_, 0) -> [0, prime ^ ((expo + 1) `quot` 2) .. modulus - 1]
    (v, unit)
      -- A square can only carry an even power of p.
      | odd v     -> []
      | otherwise ->
          let half = intToWord v `quot` 2
          in  maybe [] (liftRoots half) (unitRoot unit (expo - 2 * half))
  where
    modulus = prime ^ expo

    -- A square root of the unit part modulo p^e, if there is one.
    unitRoot u e
      | prime == 2 = sqM2P u e
      | otherwise  = sqrtModPP' u prime e

    -- Turn one root r of the unit part into every root modulo p^expo.
    -- Roots are determined modulo p^step; for p = 2 with half + 1 == step
    -- the two signs coincide, so only one family is produced.
    liftRoots half r
      | prime == 2 && half + 1 == step = spread base
      | otherwise                      = spread base ++ spread (modulus - base)
      where
        base    = r * prime ^ half
        step    = max (if prime == 2 then expo - half - 1 else expo - half)
                      ((expo + 1) `quot` 2)
        offsets = [0, prime ^ step .. modulus - 1]
        -- Equivalent to map ((`mod` modulus) . (+ s)) offsets,
        -- but avoids division.
        spread s = map (+ s) lo ++ map (+ (s - modulus)) hi
          where
            (lo, hi) = span (< modulus - s) offsets

-- | List all square roots modulo a prime.
--
-- Modulo an odd prime @p@ the result is @[r, p - r]@ when the argument is
-- a nonzero quadratic residue, @[0]@ when @p@ divides it, and @[]@ when it
-- is a non-residue. Modulo 2 the single root is the argument reduced
-- modulo 2.
--
-- >>> sqrtsModPrime 1 (fromJust (isPrime 5))
-- [1,4]
-- >>> sqrtsModPrime 0 (fromJust (isPrime 5))
-- [0]
-- >>> sqrtsModPrime 2 (fromJust (isPrime 5))
-- []
sqrtsModPrime :: Integer -> Prime Integer -> [Integer]
sqrtsModPrime n (PrimeNat 2) = [n `mod` 2]
sqrtsModPrime n (PrimeNat (toInteger -> prime)) = case jacobi n prime of
  MinusOne -> []
  Zero     -> [0]
  One      -> let r = sqrtModP' (n `mod` prime) prime in [r, prime - r]

-------------------------------------------------------------------------------
-- Internals

-- | @sqrtModP' square prime@ finds a square root of @square@ modulo
-- @prime@. @prime@ /must/ be a (positive) prime, and @square@ /must/ be a
-- positive quadratic residue modulo @prime@, i.e.
-- @'jacobi' square prime == 'One'@.
--
-- Primes of the form @4k + 3@ need a single modular exponentiation,
-- a square of @-1@ modulo a prime @4k + 1@ is handled by 'sqrtOfMinusOne',
-- and all remaining cases go through 'tonelliShanks'.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square prime
  | prime == 2 = square
  | rem4 prime == 3 = powMod square ((prime + 1) `quot` 4) prime
  | square `mod` prime == prime - 1 = sqrtOfMinusOne prime
  | otherwise = tonelliShanks square prime

-- | @sqrtOfMinusOne p@ finds a square root of @-1@ modulo a prime @p@
-- of the form @4k + 1@ (precondition; not checked).
--
-- For a quadratic non-residue @g@, Euler's criterion gives
-- @g ^ ((p - 1) \`quot\` 2) == -1@ modulo @p@, so
-- @g ^ ((p - 1) \`quot\` 4)@ squares to @-1@. The non-residue is taken
-- from 'findNonSquare'. The smallest non-residue is always prime, so this
-- picks the same base as a linear scan over 2, 3, ... would. It needs a
-- single modular exponentiation and no partial 'head'.
sqrtOfMinusOne :: Integer -> Integer
sqrtOfMinusOne p = powMod (findNonSquare p) ((p - 1) `quot` 4) p

-- | @tonelliShanks square prime@ calculates a square root of @square@
-- modulo @prime@, where @prime@ is a prime of the form @4*k + 1@ and
-- @square@ is a positive quadratic residue modulo @prime@, using the
-- Tonelli-Shanks algorithm. The required quadratic non-residue is
-- obtained from 'findNonSquare'.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = loop rc t1 generator log2
  where
    -- prime - 1 split into a power of two (2^log2) and an odd part q.
    (log2,q) = shiftToOddCount (prime-1)
    nonSquare = findNonSquare prime
    -- nonSquare^q: the initial value of c in the main loop.
    generator = powMod nonSquare q prime
    -- Initial candidate root square^((q+1)/2) and error term square^q;
    -- rc^2 == square * t1 (mod prime).
    rc = powMod square ((q+1) `quot` 2) prime
    t1 = powMod square q prime
    -- One squaring modulo prime, and k repeated squarings.
    msqr x = (x*x) `rem` prime
    msquare 0 x = x
    msquare k x = msquare (k-1) (msqr x)
    -- Number of squarings needed to bring x to 1 (diverges if x never
    -- reaches 1, e.g. x == 0, which the preconditions exclude).
    findPeriod per 1 = per
    findPeriod per x = findPeriod (per+1) (msqr x)
    -- Invariant: r^2 == square * t (mod prime). Each step multiplies r by b
    -- and t by b^2, preserving it, so r is a root once t == 1.
    loop :: Integer -> Integer -> Integer -> Int -> Integer
    loop !r t c m
      | t == 1 = r
      | otherwise = loop nextR nextT nextC nextM
      where
        nextM = findPeriod 0 t
        b = msquare (m - 1 - nextM) c
        nextR = (r*b) `rem` prime
        nextC = msqr b
        nextT = (t*nextC) `rem` prime

-- | @sqrtModPP' n prime expo@ looks for a square root of @n@ modulo
-- @prime^expo@. It takes the first root @r@ modulo @prime@ from
-- 'sqrtsModPrime' and lifts it Hensel-style, one or more powers of @prime@
-- at a time. It returns 'Nothing' when there is no root modulo @prime@
-- (or when @2*r@ has no inverse modulo @prime@, which the preconditions
-- rule out).
--
-- Preconditions: prime must be odd, n must be coprime with prime.
sqrtModPP' :: Integer -> Integer -> Word -> Maybe Integer
sqrtModPP' n prime expo = case sqrtsModPrime n (PrimeNat (fromInteger prime)) of
  [] -> Nothing
  r : _ -> fixup r
  where
    -- diff' = prime^e * q; if prime^expo already divides it, r is a root.
    fixup r = let diff' = r*r-n in
      if diff' == 0 then Just r else
      case splitOff prime diff' of
        (e,q) | expo <= intToWord e -> Just r
              | otherwise -> fmap (\inv -> hoist inv r (q `mod` prime) (prime^e)) (recipMod (2*r) prime)
    -- One lifting step: root^2 == n (mod pp) and elim is ((root^2 - n) / pp)
    -- mod prime. Adding inv * (prime - elim) * pp cancels that term modulo
    -- prime * pp. inv is the inverse of 2*r modulo prime; every lifted root is
    -- congruent to r modulo prime, so the same inverse stays valid.
    -- Recursion stops once root'^2 == n exactly or prime^expo divides the
    -- difference.
    hoist inv root elim pp
      | diff' == 0 = root'
      | expo <= intToWord ex = root'
      | otherwise = hoist inv root' (nelim `mod` prime) (prime^ex)
      where
        root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)
        -- Shadows the outer diff': the residual error of the new root.
        diff' = root'*root' - n
        (ex, nelim) = splitOff prime diff'

-- | @sqM2P n e@ looks for a square root of @n@ modulo @2^e@, returning
-- 'Nothing' when @n@ is not a square modulo @2^e@.
--
-- For @e < 2@ the answer is @n \`mod\` 2@. Otherwise @n \`mod\` 2^e@ is split
-- as @2^k * s@ with @s@ odd. A root exists only for even @k@, and a root
-- of @s@ modulo @2^(e - k)@ is shifted left by @k/2@ bits and reduced
-- modulo @2^e@. (Original author's note on this function: "dirty, dirty".)
sqM2P :: Integer -> Word -> Maybe Integer
sqM2P n (wordToInt -> e)
  | e < 2 = Just (n `mod` 2)
  | n' == 0 = Just 0
  | odd k = Nothing
  | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
  where
    mdl = 1 `shiftL` e
    n' = n `mod` mdl
    (k, s) = shiftToOddCount n'
    k2 = k `quot` 2
    e2 = e - k
    -- Root of the odd number r modulo 2^e2. The second argument is only
    -- inspected for the trivial case 1; fixup reads e2 from the outer scope.
    solve _ 1 = Just 1
    solve 1 _ = Just 1
    solve r _
      | rem4 r == 3 = Nothing -- otherwise r ≡ 1 (mod 4)
      | rem8 r == 5 = Nothing -- otherwise r ≡ 1 (mod 8)
      | otherwise = fixup r (fst $ shiftToOddCount (r-1))
      where
        -- pw is the power of 2 dividing x*x - r (initially r - 1, which has
        -- the same power of 2 as r*r - r because r is odd). Adding
        -- 2^(pw-1) to x raises that power; stop once it reaches e2.
        fixup x pw
          | pw >= e2 = Just x
          | otherwise = fixup x' pw'
          where
            x' = x + (1 `shiftL` (pw-1))
            d = x'*x' - r
            pw' = if d == 0 then e2 else fst (shiftToOddCount d)

-------------------------------------------------------------------------------
-- Utilities

-- | Low two bits of @n@ (its residue modulo 4) as an 'Int', computed by
-- masking the 'Int' conversion of @n@.
rem4 :: Integral a => a -> Int
rem4 n = fromIntegral n .&. 3

-- | Low three bits of @n@ (its residue modulo 8) as an 'Int', computed by
-- masking the 'Int' conversion of @n@.
rem8 :: Integral a => a -> Int
rem8 n = fromIntegral n .&. 7

-- | A quadratic non-residue modulo @n@ (presumably an odd prime; verify
-- against callers). For @n@ congruent to 3 or 5 modulo 8 this is 2.
-- Otherwise primes are scanned in increasing order for the first one
-- whose Jacobi symbol modulo @n@ is 'MinusOne'.
findNonSquare :: Integer -> Integer
findNonSquare n
  | rem8 n == 5 || rem8 n == 3 = 2
  | otherwise = search primelist
  where
    -- The sieve's starting point depends on n, presumably so the infinite
    -- prime list cannot be floated out and retained across calls.
    primelist = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
                ++ sieveFrom (68 + n `rem` 4) -- prevent sharing
    search (p:ps) = case jacobi p n of
      MinusOne -> p
      _ -> search ps
    search _ = error "Should never have happened, prime list exhausted."