{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}
module Math.NumberTheory.Moduli.Sqrt
(
sqrtsMod
, sqrtsModFactorisation
, sqrtsModPrimePower
, sqrtsModPrime
, Old.sqrtModP
, Old.sqrtModPList
, Old.sqrtModP'
, Old.tonelliShanks
, Old.sqrtModPP
, Old.sqrtModPPList
, Old.sqrtModF
, Old.sqrtModFList
) where
import Control.Arrow hiding (loop)
import Control.Monad (liftM2)
import Data.Bits
import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Class (Mod, getVal, getMod, KnownNat)
import Math.NumberTheory.Moduli.Jacobi
import Math.NumberTheory.Powers.Modular (powMod)
import qualified Math.NumberTheory.Primes.Factorisation as F (factorise)
import Math.NumberTheory.Primes.Types
import Math.NumberTheory.Primes.Sieve (sieveFrom)
import Math.NumberTheory.Utils (shiftToOddCount, splitOff, recipMod)
import Math.NumberTheory.Utils.FromIntegral
import qualified Math.NumberTheory.Moduli.SqrtOld as Old
-- | List all square roots of a residue modulo its (type-level) modulus.
--
-- The modulus is factorised into prime powers, and the roots are
-- assembled from the per-prime-power roots by 'sqrtsModFactorisation'.
sqrtsMod :: KnownNat m => Mod m -> [Mod m]
sqrtsMod a = fromInteger <$> sqrtsModFactorisation (getVal a) primePowers
  where
    primePowers =
      [ (PrimeNat (integerToNatural p), intToWord k)
      | (p, k) <- F.factorise (getMod a)
      ]
-- | List all square roots of @n@ modulo the product of the given prime
-- powers.
--
-- Roots are computed independently modulo each @p^k@ by
-- 'sqrtsModPrimePower'. Every combination of them is then glued together
-- with 'chineseRemainder2'. An empty factorisation stands for the modulus 1,
-- whose only residue is 0.
sqrtsModFactorisation :: Integer -> [(Prime Integer, Word)] -> [Integer]
sqrtsModFactorisation _ [] = [0]
sqrtsModFactorisation n pps = map fst (foldl1 crossCombine perPrimePower)
  where
    -- For each prime power p^k: the roots modulo p^k, each tagged with p^k.
    perPrimePower :: [[(Integer, Integer)]]
    perPrimePower =
      [ [ (root, modulus) | root <- sqrtsModPrimePower n p pow ]
      | (p@(PrimeNat q), pow) <- pps
      , let modulus = toInteger q ^ pow
      ]

    -- Pair every root combined so far with every root for the next modulus,
    -- merging each pair by the Chinese remainder theorem.
    crossCombine acc next =
      [ (chineseRemainder2 a b, m1 * m2) | a@(_, m1) <- acc, b@(_, m2) <- next ]
-- | List all square roots of @nn@ modulo @p^expo@.
--
-- Write @nn mod p^expo = p^kk * n@ with @n@ coprime to @p@.
--
-- * If @nn == 0 (mod p^expo)@, the roots are exactly the multiples of
--   @p^ceiling(expo/2)@.
-- * If @kk@ is odd, there are no roots.
-- * Otherwise one root @r@ of @n@ modulo @p^(expo - kk)@ is computed
--   ('sqM2P' for @p = 2@, 'sqrtModPP'' for odd @p@) and scaled to
--   @rr = r * p^(kk/2)@. All roots are then @rr@ and @p^expo - rr@ shifted
--   by multiples of @p^t@.
sqrtsModPrimePower :: Integer -> Prime Integer -> Word -> [Integer]
-- Exponent 1 is delegated to the prime-modulus solver.
sqrtsModPrimePower nn p 1 = sqrtsModPrime nn p
sqrtsModPrimePower nn (PrimeNat (toInteger -> prime)) expo = let primeExpo = prime ^ expo in
  case splitOff prime (nn `mod` primeExpo) of
    -- nn == 0 (mod p^expo): x^2 == 0 iff p^ceiling(expo/2) divides x.
    (_, 0) -> [0, prime ^ ((expo + 1) `quot` 2) .. primeExpo - 1]
    (kk, n)
      -- A square has an even p-adic valuation.
      | odd kk    -> []
      | otherwise -> case (if prime == 2 then sqM2P n expo' else sqrtModPP' n prime expo') of
        Nothing -> []
        Just r  -> let rr = r * prime ^ k in
          -- For p = 2 with k + 1 == t, -rr == rr (mod 2^t), so the negated
          -- root lies in the same class and adding it would duplicate roots.
          if prime == 2 && k + 1 == t
          then go rr os
          else go rr os ++ go (primeExpo - rr) os
      where
        k = intToWord kk `quot` 2
        -- (x + p^t*j)^2 - x^2 = 2*x*p^t*j + p^(2t)*j^2 with v_p(x) = k.
        -- The cross term vanishes mod p^expo once k + t >= expo (one extra
        -- factor 2 helps when p = 2), and the square term once 2t >= expo.
        t = (if prime == 2 then expo - k - 1 else expo - k) `max` ((expo + 1) `quot` 2)
        expo' = expo - 2 * k
        -- Offsets 0, p^t, 2*p^t, ... below p^expo.
        os = [0, prime ^ t .. primeExpo - 1]

        -- Equivalent to map ((`mod` primeExpo) . (+ r)) rs for r in
        -- [0, primeExpo) and ascending rs in [0, primeExpo), but avoids
        -- division.
        go r rs = map (+ r) ps ++ map (+ (r - primeExpo)) qs
          where
            (ps, qs) = span (< primeExpo - r) rs
-- | List the square roots of @n@ modulo a prime.
--
-- Modulo 2 the single root is @n mod 2@. For odd primes the Jacobi symbol
-- decides the outcome:
--
-- * no roots for a non-residue;
-- * the single root 0 when @prime@ divides @n@;
-- * otherwise a root @r@ together with @prime - r@.
sqrtsModPrime :: Integer -> Prime Integer -> [Integer]
sqrtsModPrime n (PrimeNat 2) = [n `mod` 2]
sqrtsModPrime n (PrimeNat (toInteger -> prime)) =
  case jacobi n prime of
    MinusOne -> []
    Zero     -> [0]
    One      -> [root, prime - root]
  where
    root = sqrtModP' (n `mod` prime) prime
-- | One square root of @square@ modulo @prime@.
--
-- In this module it is only called from 'sqrtsModPrime', after the Jacobi
-- symbol has confirmed that @square@ is a nonzero quadratic residue. The
-- method depends on @prime@ and @square@:
--
-- * @prime == 2@: the residue is its own root.
-- * @prime == 3 (mod 4)@: @square^((prime+1)/4)@ is a root.
-- * @square == -1 (mod prime)@: use the dedicated 'sqrtOfMinusOne'.
-- * otherwise: fall back to 'tonelliShanks'.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square prime
  | prime == 2   = square
  | threeModFour = powMod square quarterExp prime
  | minusOne     = sqrtOfMinusOne prime
  | otherwise    = tonelliShanks square prime
  where
    threeModFour = rem4 prime == 3
    quarterExp   = (prime + 1) `quot` 4
    minusOne     = square `mod` prime == prime - 1
-- | A square root of @-1@ modulo an odd prime @p@ with @p == 1 (mod 4)@.
--
-- 'sqrtModP'' only calls this after ruling out @p == 2@ and
-- @p == 3 (mod 4)@.
--
-- If @g@ is a quadratic non-residue, Euler's criterion gives
-- @g^((p-1)/2) == -1@. Hence @g^((p-1)/4)@ squares to @-1@.
--
-- The non-residue comes from 'findNonSquare', the same helper
-- 'tonelliShanks' uses. It screens candidates with Jacobi symbols instead
-- of one modular exponentiation per candidate, and it never runs off the
-- end of a list.
--
-- It returns 2 when 2 is a non-residue, and otherwise the first odd prime
-- non-residue. The least non-residue modulo a prime is always prime, so
-- this is the least non-residue.
sqrtOfMinusOne :: Integer -> Integer
sqrtOfMinusOne p = powMod (findNonSquare p) ((p - 1) `quot` 4) p
-- | Square root of @square@ modulo an odd @prime@ by the Tonelli-Shanks
-- algorithm.
--
-- Write @prime - 1 = 2^s * q@ with @q@ odd.
--
-- * The loop maintains @r^2 == square * t (mod prime)@.
-- * It starts from @r = square^((q+1)/2)@ and @t = square^q@, and stops
--   once @t == 1@.
-- * Each round finds the least @i@ with @t^(2^i) == 1@, sets
--   @b = c^(2^(m-1-i))@, multiplies @r@ by @b@ and @t@ by @b^2@, and
--   replaces @c@ with @b^2@.
-- * @c@ starts as @z^q@ for the non-residue @z@ from 'findNonSquare'.
--
-- NOTE(review): @square@ must be a nonzero quadratic residue. For a
-- non-residue the first round gets @i == s@, the squaring count
-- @m - 1 - i@ is negative, and the loop does not terminate.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = step r0 t0 c0 s0
  where
    (s0, q) = shiftToOddCount (prime - 1)
    c0 = powMod (findNonSquare prime) q prime
    r0 = powMod square ((q + 1) `quot` 2) prime
    t0 = powMod square q prime

    -- Modular squaring.
    sq :: Integer -> Integer
    sq x = (x * x) `rem` prime

    -- Square x exactly j times.
    squareTimes :: Int -> Integer -> Integer
    squareTimes 0 x = x
    squareTimes j x = squareTimes (j - 1) (sq x)

    -- Number of squarings needed to bring x to 1.
    orderLog :: Integer -> Int
    orderLog = length . takeWhile (/= 1) . iterate sq

    step :: Integer -> Integer -> Integer -> Int -> Integer
    step !r t c m
      | t == 1    = r
      | otherwise = step ((r * b) `rem` prime) ((t * c') `rem` prime) c' i
      where
        i  = orderLog t
        b  = squareTimes (m - 1 - i) c
        c' = sq b
-- | One square root of @n@ modulo @prime^expo@, or 'Nothing' if none is
-- found.
--
-- A root modulo @prime@ is taken from 'sqrtsModPrime' and lifted to
-- higher powers of @prime@ by Hensel lifting.
--
-- 'sqrtsModPrimePower' uses this only for odd primes, with @n@ already
-- coprime to @prime@. The lifting step inverts @2 * r@ modulo @prime@
-- via 'recipMod', which needs exactly those conditions.
sqrtModPP' :: Integer -> Integer -> Word -> Maybe Integer
sqrtModPP' n prime expo = case sqrtsModPrime n (PrimeNat (fromInteger prime)) of
    -- Not a quadratic residue modulo prime: no root at any higher power.
    []    -> Nothing
    r : _ -> fixup r
  where
    -- If r^2 - n is already divisible by prime^expo, r is a root.
    -- Otherwise start lifting from the valuation e of r^2 - n.
    fixup r = let diff' = r*r-n
              in if diff' == 0
                   then Just r
                   else case splitOff prime diff' of
                     (e,q) | expo <= intToWord e -> Just r
                           | otherwise -> fmap (\inv -> hoist inv r (q `mod` prime) (prime^e)) (recipMod (2*r) prime)

    -- One lifting step. Invariants:
    --   * inv * 2 * root == 1 (mod prime);
    --   * root^2 - n == pp * q with pp a power of prime and elim == q (mod prime).
    -- Adding inv * (prime - elim) * pp to root cancels the leading term, so
    -- root' squares to n modulo prime * pp. Recurse until the valuation of
    -- root'^2 - n reaches expo.
    hoist inv root elim pp
      | diff' == 0    = root'
      | expo <= intToWord ex = root'
      | otherwise     = hoist inv root' (nelim `mod` prime) (prime^ex)
      where
        root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)
        diff' = root'*root' - n
        (ex, nelim) = splitOff prime diff'
-- | One square root of @n@ modulo @2^e@, or 'Nothing' if none exists.
--
-- Write @n mod 2^e = 2^k * s@ with @s@ odd. A root exists only for even
-- @k@. It is @2^(k/2)@ times a root of @s@ modulo @2^(e - k)@, reduced
-- modulo @2^e@.
sqM2P :: Integer -> Word -> Maybe Integer
sqM2P n (wordToInt -> e)
  | e < 2     = Just (n `mod` 2)
  | n' == 0   = Just 0
  -- An odd 2-adic valuation cannot be a square.
  | odd k     = Nothing
  | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
  where
    mdl = 1 `shiftL` e
    n' = n `mod` mdl
    (k, s) = shiftToOddCount n'
    k2 = k `quot` 2
    e2 = e - k

    -- Root of the odd r modulo 2^e2 (note r < 2^e2).
    solve _ 1 = Just 1
    solve 1 _ = Just 1
    solve r _
      -- Odd squares are 1 (mod 4) ...
      | rem4 r == 3 = Nothing
      -- ... and 1 (mod 8). Reaching here implies r >= 5, hence e2 >= 3.
      | rem8 r == 5 = Nothing
      | otherwise = fixup r (fst $ shiftToOddCount (r-1))
      where
        -- Invariant: x is odd and pw is the 2-adic valuation of x^2 - r.
        -- Initially x = r, and r^2 - r = r*(r-1) has the valuation of r-1.
        -- Adding 2^(pw-1) to x strictly raises that valuation (pw >= 3).
        fixup x pw
          | pw >= e2  = Just x
          | otherwise = fixup x' pw'
          where
            x' = x + (1 `shiftL` (pw-1))
            d = x'*x' - r
            pw' = if d == 0 then e2 else fst (shiftToOddCount d)
-- | Residue modulo 4, obtained by masking the low two bits after
-- converting to 'Int'.
--
-- The conversion keeps the low bits, so negative arguments and values
-- wider than 'Int' also yield their non-negative residue.
rem4 :: Integral a => a -> Int
rem4 = (.&. 3) . fromIntegral
-- | Residue modulo 8, obtained by masking the low three bits after
-- converting to 'Int'.
--
-- The conversion keeps the low bits, so negative arguments and values
-- wider than 'Int' also yield their non-negative residue.
rem8 :: Integral a => a -> Int
rem8 = (.&. 7) . fromIntegral
-- | A quadratic non-residue modulo the odd prime @n@.
--
-- For @n == 3@ or @5 (mod 8)@, 2 is a non-residue. Otherwise odd primes
-- are tried in increasing order until one has Jacobi symbol -1.
findNonSquare :: Integer -> Integer
findNonSquare n
  | r8 == 3 || r8 == 5 = 2
  | otherwise          = firstNonResidue candidates
  where
    r8 = rem8 n
    -- Small primes first; beyond 67 fall back to the sieve. The sieve start
    -- depends on n, presumably so the list is not floated out and shared
    -- between calls. NOTE(review): confirm this intent.
    candidates = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
      ++ sieveFrom (68 + n `rem` 4)
    firstNonResidue (p : ps)
      | MinusOne <- jacobi p n = p
      | otherwise              = firstNonResidue ps
    firstNonResidue [] = error "Should never have happened, prime list exhausted."