module Math.NumberTheory.Moduli
(
jacobi
, invertMod
, powerMod
, powerModInteger
, chineseRemainder
, sqrtModP
, jacobi'
, powerMod'
, powerModInteger'
, sqrtModPList
, sqrtModP'
, tonelliShanks
, sqrtModPP
, sqrtModPPList
, sqrtModF
, sqrtModFList
, chineseRemainder2
) where
#include "MachDeps.h"
#if __GLASGOW_HASKELL__ < 709 || WORD_SIZE_IN_BITS == 32
import Data.Word
#endif
import Data.Bits
import Data.Array.Unboxed
import Data.List (nub)
import Control.Monad (foldM, liftM2)
import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
import Math.NumberTheory.GCD (extendedGCD)
import Math.NumberTheory.Primes.Heap (sieveFrom)
import Math.NumberTheory.Unsafe
-- | Invert a number relative to a positive modulus.
--
--   @invertMod k m@ returns @Just i@ when the candidate @i@ produced by the
--   Euclidean recurrence satisfies @(i * k') `rem` m == 1@, where @k'@ is
--   the representative of @k@ in @[0, m)@; otherwise it returns @Nothing@.
--
--   Calls 'error' when @m <= 0@.
invertMod :: Integer -> Integer -> Maybe Integer
invertMod k m
  | m <= 0    = error "Math.NumberTheory.Moduli.invertMod: non-positive modulus"
  | otherwise = wrap $ go False 1 0 m k'
  where
    -- Non-negative representative of k modulo m.
    k' | r < 0     = r + m
       | otherwise = r
      where
        r = k `rem` m

    -- Only accept the candidate if it really is an inverse of k'.
    wrap x = case (x * k') `rem` m of
      1 -> Just x
      _ -> Nothing

    -- Runs the Euclidean remainder sequence on (m, k') while tracking the
    -- current (pn) and previous (po) values of the recurrence
    -- p_next = q * p_current + p_previous.  The flag b alternates on every
    -- step and selects the sign of the result: when it is False the
    -- candidate is the negative of po, represented as m - po.
    go !b _ po _ 0 = if b then po else m - po
    go b !pn po n d = case n `quotRem` d of
      (q, r) -> go (not b) (q * pn + po) pn d r
-- | Jacobi symbol of @a@ over @b@ with argument checking.
--
--   Calls 'error' when the denominator @b@ is negative or even.  The
--   trivial cases (@b == 1@, @a == 0@, @a == 1@) are answered directly;
--   everything else is delegated to the unchecked 'jacobi''.
jacobi :: (Integral a, Bits a) => a -> a -> Int
jacobi a b
  | b < 0            = error "Math.NumberTheory.Moduli.jacobi: negative denominator"
  | evenI b          = error "Math.NumberTheory.Moduli.jacobi: even denominator"
  | b == 1 || a == 1 = 1
  | a == 0           = 0
  | otherwise        = jacobi' a b
-- | Jacobi symbol of @a@ over @b@ without argument checking.
--
--   No check is made here that @b@ is positive and odd; the checked entry
--   point 'jacobi' performs those tests before delegating to this function.
jacobi' :: (Integral a, Bits a) => a -> a -> Int
jacobi' a b
  | a == 0 = 0
  | a == 1 = 1
  | a < 0 =
      -- n is the symbol (-1/b): 1 when b is 1 modulo 4, -1 otherwise.
      let n | rem4 b == 1 = 1
            | otherwise   = -1
          -- Work on the absolute value as an Integer so that negating the
          -- smallest value of a fixed-width type cannot overflow.
          -- presumably shiftToOddCount returns (z, o) with |a| = 2^z * o, o odd
          (z, o) = shiftToOddCount (abs $ toInteger a)
          -- An odd power of two contributes the symbol (2/b) from jac2.
          s | evenI z || unsafeAt jac2 (rem8 b) == 1 = n
            | otherwise                              = negate n
      in s * jacobi' (fromInteger o) b
  | a >= b = case a `rem` b of
      0 -> 0
      r -> jacPS 1 r b
  | evenI a = case shiftToOddCount a of
      (z, o) ->
        -- Reciprocity sign for swapping o and b: -1 exactly when both are
        -- 3 modulo 4 (then the masked value is 3), otherwise 1.
        let r = 2 - (rem4 o .&. rem4 b)
            s | evenI z || unsafeAt jac2 (rem8 b) == 1 = r
              | otherwise                              = negate r
        in jacOL s b o
  | otherwise = case rem4 a .&. rem4 b of
      3 -> jacOL (-1) b a
      _ -> jacOL 1 b a
-- | Jacobi helper used when the numerator @a@ is a non-zero remainder
--   modulo the denominator @b@ (both call sites pass @a `rem` b@).
--
--   @j@ is the sign accumulated so far.  Factors of two are removed from
--   @a@, the sign is adjusted for them and for the reciprocity swap, and the
--   computation continues in 'jacOL' with numerator and denominator swapped.
jacPS :: (Integral a, Bits a) => Int -> a -> a -> Int
jacPS !j a b
  | evenI a = case shiftToOddCount a of
      (z, o)
        -- The power of two contributes +1: only reciprocity may flip j.
        | evenI z || unsafeAt jac2 (rem8 b) == 1 ->
            jacOL (if rem4 o .&. rem4 b == 3 then negate j else j) b o
        -- The power of two contributes -1, which cancels or adds to the
        -- reciprocity flip.
        | otherwise ->
            jacOL (if rem4 o .&. rem4 b == 3 then j else negate j) b o
  | otherwise = jacOL (if rem4 a .&. rem4 b == 3 then negate j else j) b a
-- | Jacobi helper that reduces the numerator @a@ modulo the denominator @b@.
--
--   Returns the accumulated sign @j@ once the denominator has reached 1,
--   0 when @b@ divides @a@, and otherwise continues in 'jacPS' with the
--   remainder as the new numerator.
jacOL :: (Integral a, Bits a) => Int -> a -> a -> Int
jacOL !j a b
  | b == 1         = j
  | remainder == 0 = 0
  | otherwise      = jacPS j remainder b
  where
    remainder = a `rem` b
-- | Modular power with argument checking; a thin public name for
--   'powerModImpl', which holds the actual implementation.
powerMod :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod base expo md = powerModImpl base expo md
-- | Checked modular power @base ^ expo@ modulo @md@.
--
--   * Calls 'error' for a non-positive modulus.
--   * A modulus of 1 yields 0, and a zero exponent yields 1.
--   * A negative exponent inverts the reduced base first and calls 'error'
--     if that base has no inverse modulo @md@.
powerModImpl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerModImpl base expo md
  | md <= 0      = error "Math.NumberTheory.Moduli.powerMod: non-positive modulus"
  | md == 1      = 0
  | expo == 0    = 1
  | reduced == 1 = 1
  | expo < 0     =
      maybe
        (error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus")
        (\inverse -> powerMod'Impl inverse (negate expo) md)
        (invertMod reduced md)
  | reduced == 0 = 0
  | otherwise    = powerMod'Impl reduced expo md
  where
    -- Base brought into [0, md); the division is skipped when it is
    -- already in range.
    reduced
      | 0 <= base && base < md = base
      | otherwise              = base `mod` md
-- | Modular power without argument checking; a thin public name for
--   'powerMod'Impl'.
powerMod' :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod' base expo md = powerMod'Impl base expo md
-- | Unchecked modular power by binary square-and-multiply.
--
--   The exponent is scanned from its least significant bit; the loop stops
--   when the remaining exponent equals 1.  No equation handles a zero
--   exponent, so callers must pass a positive one.
powerMod'Impl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod'Impl base expo md = loop expo 1 base
  where
    -- Product reduced modulo md.
    mulMod x y = (x * y) `rem` md

    -- acc is the product collected so far, sq the current repeated square.
    loop 1 !acc !sq = mulMod acc sq
    loop e acc sq   = loop (e `shiftR` 1) acc' (mulMod sq sq)
      where
        acc' | testBit e 0 = mulMod acc sq
             | otherwise   = acc
-- | Checked modular power with an 'Integer' exponent.
--
--   * Calls 'error' for a non-positive modulus.
--   * A modulus of 1 yields 0, and a zero exponent yields 1.
--   * A negative exponent inverts the reduced base first and calls 'error'
--     if that base has no inverse modulo @mdl@.
powerModInteger :: Integer -> Integer -> Integer -> Integer
powerModInteger base ex mdl
  | mdl <= 0     = error "Math.NumberTheory.Moduli.powerModInteger: non-positive modulus"
  | mdl == 1     = 0
  | ex == 0      = 1
  | ex < 0       =
      maybe
        (error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus")
        (\inverse -> powerModInteger' inverse (negate ex) mdl)
        (invertMod reduced mdl)
  | reduced == 0 = 0
  | reduced == 1 = 1
  | otherwise    = powerModInteger' reduced ex mdl
  where
    -- Base brought into [0, mdl); the division is skipped when it is
    -- already in range.
    reduced
      | 0 <= base && base < mdl = base
      | otherwise               = base `mod` mdl
-- | Unchecked modular power with an Integer exponent.
--
--   Square-and-multiply: the accumulator is multiplied by the running square
--   whenever the current exponent bit is set, and every product is reduced
--   with rem md.  The exponent is consumed in 64-bit pieces.
--
--   NOTE(review): no equation of end or fin handles a zero word, so a zero
--   (and presumably a negative) exponent looks non-terminating; confirm that
--   every caller passes a positive exponent.
powerModInteger' :: Integer -> Integer -> Integer -> Integer
powerModInteger' base expo md = go w1 1 base e1
where
-- w1: the exponent converted to the word type taken by go (presumably this
-- keeps only the low 64 bits); e1: the exponent shifted right by 64 bits.
w1 = fromInteger expo
e1 = expo `shiftR` 64
#if WORD_SIZE_IN_BITS == 32
-- 32-bit build: the 64-bit piece is split into a low half (wl) and a high
-- half (wh), and each half is scanned with 32-bit Word operations.
go :: Word64 -> Integer -> Integer -> Integer -> Integer
-- No further exponent bits remain above w: finish with end.
go !w !a !s 0 = end a s w
-- Otherwise consume all 64 bits of w, then continue with the next piece of e.
go w a s e = inner1 a s 0
where
wl :: Word
!wl = fromIntegral w
wh :: Word
!wh = fromIntegral (w `shiftR` 32)
-- Scan the 32 bits of the low half, then hand over to the high half.
inner1 !au !sq 32 = inner2 au sq 0
inner1 au sq i
| testBit wl i = inner1 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
| otherwise = inner1 au ((sq*sq) `rem` md) (i+1)
-- Scan the 32 bits of the high half, then fetch the next 64-bit piece.
inner2 !au !sq 32 = go (fromInteger e) au sq (e `shiftR` 64)
inner2 au sq i
| testBit wh i = inner2 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
| otherwise = inner2 au ((sq*sq) `rem` md) (i+1)
-- Last piece: if the high half is zero only the low half is processed by
-- fin; otherwise the full low half is scanned first and fin handles the
-- high half.
end !a !s w
| wh == 0 = fin a s wl
| otherwise = innerE a s 0
where
wl :: Word
!wl = fromIntegral w
wh :: Word
!wh = fromIntegral (w `shiftR` 32)
innerE !au !sq 32 = fin au sq wh
innerE au sq i
| testBit wl i = innerE ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
| otherwise = innerE au ((sq*sq) `rem` md) (i+1)
-- Final word: stop with one last multiplication when the word equals 1,
-- otherwise consume the lowest bit and shift the word right.
fin :: Integer -> Integer -> Word -> Integer
fin !a !s 1 = (a*s) `rem` md
fin a s w
| testBit w 0 = fin ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
| otherwise = fin a ((s*s) `rem` md) (w `shiftR` 1)
#else
-- Other builds: the piece is scanned directly as a Word, with the bit
-- counter running up to 64.
go :: Word -> Integer -> Integer -> Integer -> Integer
-- No further exponent bits remain above w: finish with end.
go !w !a !s 0 = end a s w
-- Otherwise consume all 64 bits of w, then continue with the next piece of e.
go w a s e = inner a s 0
where
inner !au !sq 64 = go (fromInteger e) au sq (e `shiftR` 64)
inner au sq i
| testBit w i = inner ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
| otherwise = inner au ((sq*sq) `rem` md) (i+1)
-- Final word: stop with one last multiplication when the word equals 1,
-- otherwise consume the lowest bit and shift the word right.
end !a !s 1 = (a*s) `rem` md
end a s w
| testBit w 0 = end ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
| otherwise = end a ((s*s) `rem` md) (w `shiftR` 1)
#endif
-- | Square root of @n@ modulo @prime@, if one exists.
--
--   For @prime == 2@ the answer is @n `mod` 2@.  Otherwise the value of
--   'jacobi'' decides: 0 gives the root 0, 1 gives a root computed by
--   'sqrtModP'', and any other value means there is no root.  @prime@ is
--   not checked for primality here.
sqrtModP :: Integer -> Integer -> Maybe Integer
sqrtModP n prime
  | prime == 2  = Just (n `mod` 2)
  | symbol == 0 = Just 0
  | symbol == 1 = Just (sqrtModP' (n `mod` prime) prime)
  | otherwise   = Nothing
  where
    symbol = jacobi' n prime
-- | All square roots of @n@ modulo @prime@ as a list.
--
--   The list is empty when 'sqrtModP' finds no root, has one element when
--   the root is 0 (or the modulus is 2), and otherwise holds the root @r@
--   together with its negative @prime - r@.
sqrtModPList :: Integer -> Integer -> [Integer]
sqrtModPList n prime
  | prime == 2 = [n `mod` 2]
  | otherwise  = case sqrtModP n prime of
      Just 0 -> [0]
      Just r -> [r, prime - r]
      _      -> []
-- | Square root modulo a prime without argument checking.
--
--   For the modulus 2 the argument is returned unchanged.  When the prime
--   is 3 modulo 4 the root is the power @square ^ ((prime + 1) / 4)@;
--   every other prime is handled by 'tonelliShanks'.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square 2 = square
sqrtModP' square prime = case rem4 prime of
    3 -> powerModInteger' square quarter prime
    _ -> tonelliShanks square prime
  where
    quarter = (prime + 1) `quot` 4
-- | Square root modulo a prime by the Tonelli-Shanks algorithm.
--
--   No argument checking is done here; the caller 'sqrtModP' only reaches
--   this function for a @square@ whose Jacobi symbol over @prime@ is 1.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = loop rc t1 generator log2
  where
    -- presumably prime - 1 == 2^log2 * q with q odd; confirm against
    -- shiftToOddCount
    (log2, q) = shiftToOddCount (prime - 1)
    nonSquare = findNonSquare prime
    -- nonSquare raised to the odd part q.
    generator = powerModInteger' nonSquare q prime
    -- Initial root candidate and the error term it still carries.
    rc = powerModInteger' square ((q + 1) `quot` 2) prime
    t1 = powerModInteger' square q prime

    -- One modular squaring.
    msqr x = (x * x) `rem` prime

    -- k successive modular squarings of x.
    msquare 0 x = x
    msquare k x = msquare (k - 1) (msqr x)

    -- Number of squarings needed to turn x into 1, counted from per.
    findPeriod per 1 = per
    findPeriod per x = findPeriod (per + 1) (msqr x)

    -- r: root candidate, t: error term, c: current generator power,
    -- m: exponent bound carried over from the previous round.
    loop !r t c m
      | t == 1    = r
      | otherwise = loop nextR nextT nextC nextM
      where
        nextM = findPeriod 0 t
        b     = msquare (m - 1 - nextM) c
        nextR = (r * b) `rem` prime
        nextC = msqr b
        nextT = (t * nextC) `rem` prime
-- | Square root of @n@ modulo the prime power @prime ^ expo@, if one exists.
--
--   Powers of two are delegated to 'sqM2P'.  For other primes a root modulo
--   @prime@ from 'sqrtModP' is lifted step by step to the full prime power.
--   Neither primality nor the sign of the exponent is checked here.
sqrtModPP :: Integer -> (Integer, Int) -> Maybe Integer
sqrtModPP n (2, e) = sqM2P n e
sqrtModPP n (prime, expo) = case sqrtModP n prime of
    Just r -> fixup r
    _      -> Nothing
  where
    -- Decide whether the root modulo prime is already good enough, and
    -- otherwise start lifting with the inverse of 2 * r modulo prime.
    fixup r =
      let diff' = r * r - n
      in if diff' == 0
           then Just r
           else case splitOff prime diff' of
             -- presumably diff' == prime^e * q with q coprime to prime
             (e, q)
               | expo <= e -> Just r
               | otherwise ->
                   fmap
                     (\inv -> hoist inv r (q `mod` prime) (prime ^ e))
                     (invertMod (2 * r) prime)

    -- One lifting step: root is a square root of n modulo pp, elim is the
    -- cofactor of the error modulo prime.  The corrected root is accepted
    -- once its error vanishes or is divisible by prime ^ expo.
    hoist inv root elim pp
      | diff' == 0 = root'
      | expo <= ex = root'
      | otherwise  = hoist inv root' (nelim `mod` prime) (prime ^ ex)
      where
        root' = (root + (inv * (prime - elim)) * pp) `mod` (prime * pp)
        diff' = root' * root' - n
        (ex, nelim) = splitOff prime diff'
-- | Square root of @n@ modulo @2 ^ e@, if one exists.
--
--   For @e < 2@ the answer is @n `mod` 2@.  Otherwise @n@ is reduced modulo
--   @2 ^ e@ and split by 'shiftToOddCount' into a count @k@ and a part @s@;
--   an odd @k@ means there is no root, and for an even @k@ a root of @s@
--   modulo @2 ^ (e - k)@ is shifted left by @k / 2@ bits.
sqM2P :: Integer -> Int -> Maybe Integer
sqM2P n e
  | e < 2     = Just (n `mod` 2)
  | n' == 0   = Just 0
  | e <= k    = Just 0
  | odd k     = Nothing
  | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
  where
    mdl = 1 `shiftL` e
    n' = n `mod` mdl
    -- presumably n' == 2^k * s with s odd; confirm against shiftToOddCount
    (k, s) = shiftToOddCount n'
    k2 = k `quot` 2
    e2 = e - k

    -- solve r p: root of r modulo 2^p.  The guards reject r when it is
    -- 3 modulo 4, or 5 modulo 8 once p exceeds 2.
    solve _ 1 = Just 1
    solve 1 _ = Just 1
    solve r p
      | rem4 r == 3 = Nothing
      | p == 2      = Just 1
      | rem8 r == 5 = Nothing
      | otherwise   = fixup r (fst $ shiftToOddCount (r - 1))
      where
        -- pw is the count shiftToOddCount reports for the current error
        -- (presumably the number of correct low bits); the candidate is
        -- adjusted by 2^(pw - 1) until pw reaches the target e2.
        fixup x pw
          | pw >= e2  = Just x
          | otherwise = fixup x' pw'
          where
            x' = x + (1 `shiftL` (pw - 1))
            d = x' * x' - r
            pw' = if d == 0 then e2 else fst (shiftToOddCount d)
-- | Square root of @n@ modulo the number whose factorisation is given as a
--   list of (prime, exponent) pairs.
--
--   An empty factorisation yields @Nothing@.  Otherwise one root per prime
--   power is computed with 'sqrtModPP' and the roots are combined by
--   'chineseRemainder'; a failure in any step yields @Nothing@.
sqrtModF :: Integer -> [(Integer, Int)] -> Maybe Integer
sqrtModF _ [] = Nothing
sqrtModF n pps =
  mapM (sqrtModPP n) pps >>= \roots ->
    chineseRemainder (zip roots [p ^ e | (p, e) <- pps])
-- | All square roots of @n@ modulo the number whose factorisation is given
--   as a list of (prime, exponent) pairs.
--
--   An empty factorisation yields an empty list.  Otherwise the roots
--   modulo each prime power (from 'sqrtModPPList') are combined in every
--   possible way with 'chineseRemainder2'.
sqrtModFList :: Integer -> [(Integer, Int)] -> [Integer]
sqrtModFList _ [] = []
sqrtModFList n pps = [root | (root, _) <- foldl1 (liftM2 merge) tagged]
  where
    -- For every prime power, its roots paired with the modulus they belong to.
    tagged :: [[(Integer, Integer)]]
    tagged =
      [ [(root, modulus) | root <- sqrtModPPList n pp]
      | pp <- pps
      , let modulus = uncurry (^) pp
      ]

    -- Combine two (root, modulus) pairs into one pair for the product modulus.
    merge lhs@(_, m1) rhs@(_, m2) = (chineseRemainder2 lhs rhs, m1 * m2)
-- | All square roots of @n@ modulo the prime power @prime ^ expo@.
--
--   The list is empty when no root exists.  For odd primes the two roots
--   are a root @r@ from 'sqrtModPP' and its negative; for powers of two with
--   exponent above 1, up to four candidates are derived from the root found
--   by 'sqM2P' and duplicates are removed.
sqrtModPPList :: Integer -> (Integer, Int) -> [Integer]
sqrtModPPList n (2, 1) = [n `mod` 2]
sqrtModPPList n (2, expo) =
  case sqM2P n expo of
    Just r ->
      -- m is half the modulus, so 2 * m == 2 ^ expo.
      -- NOTE(review): for r == 0 the last candidate is 2 * m itself, which
      -- is not reduced into [0, 2 * m) - confirm whether that is intended.
      let m = 1 `shiftL` (expo - 1)
      in nub [r, (r + m) `mod` (2 * m), (m - r) `mod` (2 * m), 2 * m - r]
    _ -> []
sqrtModPPList n pe@(prime, expo) =
  case sqrtModPP n pe of
    Just 0 -> [0]
    Just r -> [prime ^ expo - r, r]
    _      -> []
-- | Combine a list of (remainder, modulus) pairs into a single remainder
--   modulo the product of all moduli.
--
--   Pairs with modulus 1 are skipped.  The result is @Nothing@ as soon as
--   the cofactor of some modulus has no inverse modulo that modulus.
chineseRemainder :: [(Integer, Integer)] -> Maybe Integer
chineseRemainder remainders = foldM step 0 remainders
  where
    -- Product of all moduli, forced once up front.
    !total = product [m | (_, m) <- remainders]

    -- Add the contribution of one pair to the accumulated solution.
    step acc (r, m)
      | m == 1    = Just acc
      | otherwise =
          case invertMod cofactor m of
            Nothing  -> Nothing
            Just inv -> Just $! (acc + inv * cofactor * r) `mod` total
      where
        cofactor = total `quot` m
-- | Combine two (remainder, modulus) pairs into one remainder modulo the
--   product of the two moduli.  No check is made that the moduli are coprime.
chineseRemainder2 :: (Integer, Integer) -> (Integer, Integer) -> Integer
chineseRemainder2 (r1, md1) (r2, md2) =
  case extendedGCD md1 md2 of
    -- presumably u * md1 + v * md2 == gcd md1 md2 (assumed to be 1), so
    -- 1 - u * md1 is 1 modulo md1 and 0 modulo md2, and symmetrically for
    -- 1 - v * md2 - confirm against extendedGCD.
    (_, u, v) -> ((1 - u * md1) * r1 + (1 - v * md2) * r2) `mod` (md1 * md2)
-- | Parity test on the lowest bit of the argument after conversion to 'Int'.
evenI :: Integral a => a -> Bool
evenI n = not (testBit (fromIntegral n :: Int) 0)
-- | The two lowest bits of the argument after conversion to 'Int'.
rem4 :: Integral a => a -> Int
rem4 = (.&. 3) . fromIntegral
-- | The three lowest bits of the argument after conversion to 'Int'.
rem8 :: Integral a => a -> Int
rem8 = (.&. 7) . fromIntegral
-- | Lookup table for the Jacobi symbol of 2 over @b@, indexed by @b@ modulo 8:
--   1 for the residues 1 and 7, -1 for the residues 3 and 5, and 0 for the
--   even residues.
jac2 :: UArray Int Int
jac2 = listArray (0, 7) [0, 1, 0, -1, 0, -1, 0, 1]
-- | Find a number whose Jacobi symbol over @n@ is -1.
--
--   2 is returned directly when @n@ is 3 or 5 modulo 8.  Otherwise a fixed
--   list of small odd primes, continued by 'sieveFrom', is searched for the
--   first entry whose symbol over @n@ is -1.
findNonSquare :: Integer -> Integer
findNonSquare n
  | rem8 n == 5 || rem8 n == 3 = 2
  | otherwise                  = search primelist
  where
    -- NOTE(review): the start value of the sieve depends on n; the reason
    -- is not visible here - confirm before simplifying.
    primelist =
      [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67]
        ++ sieveFrom (68 + n `rem` 4)

    search (p : ps)
      | jacobi' p n == (-1) = p
      | otherwise           = search ps
    search _ = error "Should never have happened, prime list exhausted."