module Math.NumberTheory.Moduli
(
jacobi
, invertMod
, powerMod
, powerModInteger
, chineseRemainder
, sqrtModP
, jacobi'
, powerMod'
, powerModInteger'
, sqrtModPList
, sqrtModP'
, tonelliShanks
, sqrtModPP
, sqrtModPPList
, sqrtModF
, sqrtModFList
, chineseRemainder2
) where
#include "MachDeps.h"
import Data.Word
import Data.Bits
import Data.Array.Unboxed
import Data.Array.Base (unsafeAt)
import Data.Maybe (fromJust)
import Data.List (nub)
import Control.Monad (foldM, liftM2)
import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
import Math.NumberTheory.GCD (extendedGCD)
import Math.NumberTheory.Primes.Heap (sieveFrom)
-- | Invert a number relative to a modulus.
--
-- If @k@ and @m@ are coprime, the result is @Just inv@ where
-- @(k * inv) mod (abs m) == 1@ and @0 <= inv < abs m@.
-- For @m == 0@ the result is @Just k@ exactly when @k@ is @1@ or @-1@.
-- If @gcd k m > 1@ the result is 'Nothing'.
invertMod :: Integer -> Integer -> Maybe Integer
invertMod k 0 = if k == 1 || k == (-1) then Just k else Nothing
invertMod k m = wrap $ go False 1 0 m' k'
  where
    m' = abs m
    -- k reduced into [0, m')
    k' = k `mod` m'
    -- The recurrence below always yields a candidate, even when
    -- gcd k m > 1; keep it only if it really is an inverse.
    wrap x = case (x*k') `rem` m' of
               1 -> Just x
               _ -> Nothing
    -- Continued-fraction expansion of m'/k'. pn/po are the last two
    -- convergent numerators; b tracks the parity of the step count, which
    -- decides whether the inverse is po or its negation m' - po.
    go !b _ po _ 0 = if b then po else m' - po
    go b !pn po n d = case n `quotRem` d of
                        (q, r) -> go (not b) (q*pn + po) pn d r
-- | Jacobi symbol @(a/b)@ for an odd, positive denominator @b@.
--
-- Calls 'error' when @b@ is negative or even. The trivial cases
-- (@b == 1@, @a == 0@, @a == 1@) are answered directly; everything else
-- is delegated to the unchecked worker @jacobi'@.
jacobi :: (Integral a, Bits a) => a -> a -> Int
jacobi a b
  | b < 0            = error "Math.NumberTheory.Moduli.jacobi: negative denominator"
  | evenI b          = error "Math.NumberTheory.Moduli.jacobi: even denominator"
  | b == 1 || a == 1 = 1
  | a == 0           = 0
  | otherwise        = jacobi' a b
-- | Jacobi symbol @(a/b)@ without the argument checks of 'jacobi'.
--
-- The caller must ensure that @b@ is odd and positive; this is not checked
-- here. Any sign of @a@ is accepted.
jacobi' :: (Integral a, Bits a) => a -> a -> Int
jacobi' a b
  | a == 0 = 0
  | a == 1 = 1
  | a < 0  =
      -- (a/b) = (-1/b) * (2/b)^z * (o/b)  where  |a| = 2^z * o
      let -- (-1/b) is 1 iff b == 1 (mod 4)
          n | rem4 b == 1 = 1
            | otherwise   = -1
          -- abs is taken on the Integer image so that abs minBound
          -- cannot overflow in a bounded type
          (z, o) = shiftToOddCount (abs $ toInteger a)
          -- fold in (2/b)^z
          s | evenI z || unsafeAt jac2 (rem8 b) == 1 = n
            | otherwise                             = -n
      in s * jacobi' (fromInteger o) b
  | a >= b = case a `rem` b of
               0 -> 0
               r -> jacPS 1 r b
  | evenI a = case shiftToOddCount a of
                (z, o) ->
                  -- reciprocity sign: -1 iff o and b are both 3 (mod 4)
                  let r = 2 - (rem4 o .&. rem4 b)
                      s | evenI z || unsafeAt jac2 (rem8 b) == 1 = r
                        | otherwise                             = -r
                  in jacOL s b o
  | otherwise = case rem4 a .&. rem4 b of
                  -- both a and b are 3 (mod 4): reciprocity flips the sign
                  3 -> jacOL (-1) b a
                  _ -> jacOL 1 b a
-- | Jacobi helper: @jacPS j a b@ is @j@ times the Jacobi symbol @(a/b)@.
--
-- Assumes @0 < a < b@ with @b@ odd; callers pass a nonzero remainder
-- modulo @b@. Powers of two are stripped from @a@ using (2/b), then
-- quadratic reciprocity swaps the arguments for 'jacOL'.
jacPS :: (Integral a, Bits a) => Int -> a -> a -> Int
jacPS !j a b
  | evenI a = case shiftToOddCount a of
                (z, o)
                  -- (2/b)^z == 1: only the reciprocity sign matters
                  | evenI z || unsafeAt jac2 (rem8 b) == 1 ->
                      jacOL (if rem4 o .&. rem4 b == 3 then -j else j) b o
                  -- (2/b)^z == -1: the reciprocity choice is inverted
                  | otherwise ->
                      jacOL (if rem4 o .&. rem4 b == 3 then j else -j) b o
  | otherwise = jacOL (if rem4 a .&. rem4 b == 3 then -j else j) b a
-- | Jacobi helper: @jacOL sign num den@ is @sign@ times the Jacobi symbol
-- @(num/den)@, for a numerator larger than the odd denominator.
--
-- A denominator of 1 ends the recursion. Otherwise the numerator is reduced
-- modulo the denominator: a zero remainder gives 0, and a nonzero one
-- continues in 'jacPS'.
jacOL :: (Integral a, Bits a) => Int -> a -> a -> Int
jacOL !sign num den
  | den == 1       = sign
  | remainder == 0 = 0
  | otherwise      = jacPS sign remainder den
  where
    remainder = num `rem` den
-- | Modular exponentiation: @powerMod base expo md@ raises @base@ to @expo@
-- modulo @abs md@. The exponent may be of any 'Integral' type with a 'Bits'
-- instance. All special cases are handled by @powerModImpl@.
powerMod :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod base expo md = powerModImpl base expo md
-- | Implementation of @powerMod@.
--
-- * @md == 0@: plain @base ^ expo@, with no reduction.
-- * @abs md == 1@: always 0.
-- * @expo == 0@: 1.
-- * A negative exponent inverts the base modulo @abs md@ first. If the base
--   is not invertible (including a base congruent to 0), this calls 'error'.
-- * Otherwise the reduced base is raised by @powerMod'Impl@.
powerModImpl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerModImpl base expo md
  | md == 0      = base ^ expo
  | modulus == 1 = 0
  | expo == 0    = 1
  | reduced == 1 = 1
  | expo < 0     = maybe notInvertible
                         (\inv -> powerMod'Impl inv (negate expo) modulus)
                         (invertMod reduced modulus)
  | reduced == 0 = 0
  | otherwise    = powerMod'Impl reduced expo modulus
  where
    modulus = abs md
    -- only reduce when base lies outside [0, modulus)
    reduced
      | base < 0 || modulus <= base = base `mod` modulus
      | otherwise                   = base
    notInvertible =
      error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus"
-- | Modular exponentiation without the sanity checks of @powerMod@. The work
-- is delegated to @powerMod'Impl@, so that function's preconditions apply:
-- a positive exponent, and presumably an already-reduced base and a
-- positive modulus.
powerMod' :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod' base expo md = powerMod'Impl base expo md
-- | Right-to-left binary exponentiation: @base^expo@, with every product
-- reduced by @rem md@.
--
-- Performs no argument checks. The loop stops only when the exponent reaches
-- exactly 1, so @expo@ must be positive: a non-positive exponent never reaches
-- 1 under repeated right shifts. @md@ is presumably meant to be positive
-- (TODO confirm with callers).
powerMod'Impl :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod'Impl base expo md = step expo 1 base
  where
    mulMod x y = (x * y) `rem` md
    -- acc: product of the selected powers so far; sq: base^(2^i) for bit i
    step 1 !acc !sq = mulMod acc sq
    step e acc sq
      | testBit e 0 = step e' (mulMod acc sq) sq'
      | otherwise   = step e' acc sq'
      where
        e'  = e `shiftR` 1
        sq' = mulMod sq sq
-- | Modular exponentiation for an 'Integer' exponent.
--
-- Has the same special cases as @powerMod@:
--
-- * @mdl == 0@: plain @base ^ ex@.
-- * @abs mdl == 1@: 0.
-- * @ex == 0@: 1.
-- * A negative exponent first inverts the base; if that fails, this calls
--   'error'.
--
-- Everything else goes to @powerModInteger'@ with a base in
-- @[0, abs mdl)@.
powerModInteger :: Integer -> Integer -> Integer -> Integer
powerModInteger base ex mdl
  | mdl == 0     = base ^ ex
  | modulus == 1 = 0
  | ex == 0      = 1
  | ex < 0       = maybe notInvertible
                         (\inv -> powerModInteger' inv (negate ex) modulus)
                         (invertMod reduced modulus)
  | reduced == 0 = 0
  | reduced == 1 = 1
  | otherwise    = powerModInteger' reduced ex modulus
  where
    modulus = abs mdl
    -- only reduce when base lies outside [0, modulus)
    reduced
      | base < 0 || modulus <= base = base `mod` modulus
      | otherwise                   = base
    notInvertible =
      error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus"
-- | Modular exponentiation worker for an 'Integer' exponent, without checks.
--
-- The exponent is consumed in 64-bit chunks, least significant first.
-- Each chunk is converted to a machine word and its bits are scanned with
-- 'testBit'. The last (most significant) chunk is handled by @end@, which
-- stops at its highest set bit.
--
-- NOTE(review): requires @expo > 0@. With @expo == 0@, @end@ never reaches
-- its @1@ case. With a negative @expo@, the arithmetic right shift never
-- reaches 0. @md@ must be nonzero, because every step uses @rem md@.
powerModInteger' :: Integer -> Integer -> Integer -> Integer
powerModInteger' base expo md = go w1 1 base e1
  where
    -- lowest 64 bits of the exponent (fromInteger wraps) and the remainder
    w1 = fromInteger expo
    e1 = expo `shiftR` 64
#if WORD_SIZE_IN_BITS == 32
    -- 32-bit platforms: a 64-bit chunk is processed as two 32-bit halves.
    -- go w a s e: w = current chunk, a = accumulator, s = running square,
    -- e = the still unprocessed higher chunks.
    go :: Word64 -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0 = end a s w
    go w a s e = inner1 a s 0
      where
        wl :: Word
        !wl = fromIntegral w
        wh :: Word
        !wh = fromIntegral (w `shiftR` 32)
        -- all 32 bits of the low half, then all 32 bits of the high half
        inner1 !au !sq 32 = inner2 au sq 0
        inner1 au sq i
          | testBit wl i = inner1 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner1 au ((sq*sq) `rem` md) (i+1)
        -- chunk finished: continue with the next 64 bits of e
        inner2 !au !sq 32 = go (fromInteger e) au sq (e `shiftR` 64)
        inner2 au sq i
          | testBit wh i = inner2 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner2 au ((sq*sq) `rem` md) (i+1)
    -- Most significant chunk: if its high half is empty, only the low half
    -- needs the stop-at-top-bit treatment (fin); otherwise the low half is
    -- scanned fully and fin handles the high half.
    end !a !s w
      | wh == 0   = fin a s wl
      | otherwise = innerE a s 0
      where
        wl :: Word
        !wl = fromIntegral w
        wh :: Word
        !wh = fromIntegral (w `shiftR` 32)
        innerE !au !sq 32 = fin au sq wh
        innerE au sq i
          | testBit wl i = innerE ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = innerE au ((sq*sq) `rem` md) (i+1)
    -- shift-and-test until only the top bit (w == 1) remains
    fin :: Integer -> Integer -> Word -> Integer
    fin !a !s 1 = (a*s) `rem` md
    fin a s w
      | testBit w 0 = fin ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = fin a ((s*s) `rem` md) (w `shiftR` 1)
#else
    -- 64-bit platforms: one Word holds a whole chunk.
    -- go w a s e: w = current chunk, a = accumulator, s = running square,
    -- e = the still unprocessed higher chunks.
    go :: Word -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0 = end a s w
    go w a s e = inner a s 0
      where
        -- all 64 bits of a non-final chunk, then move on to the next one
        inner !au !sq 64 = go (fromInteger e) au sq (e `shiftR` 64)
        inner au sq i
          | testBit w i = inner ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise   = inner au ((sq*sq) `rem` md) (i+1)
    -- final chunk: shift-and-test until only the top bit (w == 1) remains
    end !a !s 1 = (a*s) `rem` md
    end a s w
      | testBit w 0 = end ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = end a ((s*s) `rem` md) (w `shiftR` 1)
#endif
-- | A square root of @n@ modulo @prime@, or 'Nothing' if @n@ is a quadratic
-- non-residue.
--
-- @prime@ is assumed to be prime; this is not checked. The result is
-- @Just 0@ when the Jacobi symbol is 0, i.e. when @prime@ divides @n@.
-- The prime 2 is handled separately.
sqrtModP :: Integer -> Integer -> Maybe Integer
sqrtModP n 2 = Just (n `mod` 2)
sqrtModP n prime
  | symbol == 0 = Just 0
  | symbol == 1 = Just (sqrtModP' (n `mod` prime) prime)
  | otherwise   = Nothing
  where
    symbol = jacobi' n prime
-- | All square roots of @n@ modulo the prime @prime@.
--
-- Returns:
--
-- * @[]@ when @n@ is a non-residue;
-- * @[0]@ when @prime@ divides @n@;
-- * otherwise the root @r@ from 'sqrtModP' together with @prime - r@.
--
-- Modulo 2 there is exactly one root, @n mod 2@.
sqrtModPList :: Integer -> Integer -> [Integer]
sqrtModPList n prime
  | prime == 2 = [n `mod` 2]
  | otherwise  = case sqrtModP n prime of
      Just 0  -> [0]
      Just r  -> [r, prime - r]
      Nothing -> []
-- | Square root of @square@ modulo @prime@, without checking that one exists.
--
-- Assumes @square@ is a quadratic residue modulo the prime @prime@. For
-- @prime == 3 (mod 4)@ the closed form @square^((prime+1)/4)@ is used;
-- other odd primes go through Tonelli-Shanks. Modulo 2 the input is
-- returned unchanged.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square prime
  | prime == 2      = square
  | rem4 prime /= 3 = tonelliShanks square prime
  | otherwise       = powerModInteger' square ((prime + 1) `quot` 4) prime
-- | Tonelli-Shanks square root of @square@ modulo an odd @prime@.
--
-- Assumes @square@ is a quadratic residue that is nonzero modulo @prime@.
-- This is not checked; for a non-residue, @findPeriod@ may not terminate.
--
-- The algorithm writes @prime - 1 = 2^log2 * q@ with @q@ odd. It starts
-- from @r = square^((q+1)/2)@ and @t = square^q@, which satisfy
-- @r^2 == square * t (mod prime)@. Each step multiplies in a power of the
-- generator @c@ to shrink the 2-power order of @t@, until @t == 1@.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = loop rc t1 generator log2
  where
    (log2, q) = shiftToOddCount (prime - 1)
    nonSquare = findNonSquare prime
    generator = powerModInteger' nonSquare q prime
    rc = powerModInteger' square ((q + 1) `quot` 2) prime
    t1 = powerModInteger' square q prime
    msqr x = (x*x) `rem` prime
    -- k-fold repeated squaring: x^(2^k)
    msquare 0 x = x
    msquare k x = msquare (k - 1) (msqr x)
    -- least per with x^(2^per) == 1
    findPeriod per 1 = per
    findPeriod per x = findPeriod (per + 1) (msqr x)
    -- invariant: r^2 == square * t (mod prime); c has order 2^m
    loop !r t c m
      | t == 1    = r
      | otherwise = loop nextR nextT nextC nextM
      where
        nextM = findPeriod 0 t
        b     = msquare (m - 1 - nextM) c
        nextR = (r*b) `rem` prime
        nextC = msqr b
        nextT = (t*nextC) `rem` prime
-- | @sqrtModPP n (prime, expo)@ is a square root of @n@ modulo @prime^expo@,
-- or 'Nothing' if none exists.
--
-- @prime@ is assumed to be prime. The prime 2 is delegated to 'sqM2P'.
--
-- For odd primes:
--
-- * A root modulo @prime@ (from 'sqrtModP') is lifted to @prime^expo@ by
--   Hensel lifting (@fixup@ / @hoist@).
-- * When @prime@ divides @n@ but @prime^expo@ does not, the lift cannot
--   work, because @2r@ is not invertible. Instead the prime power is
--   factored out: @n = prime^v * u@ (mod @prime^expo@, @u@ coprime to
--   @prime@). This has a root iff @v@ is even and @u@ has a root modulo
--   @prime^(expo - v)@; that root is multiplied by @prime^(v/2)@.
sqrtModPP :: Integer -> (Integer,Int) -> Maybe Integer
sqrtModPP n (2,e) = sqM2P n e
sqrtModPP n (prime,expo)
  | expo >= 1, nr == 0 = Just 0
  | expo >= 1, v > 0   =
      if odd v
        then Nothing
        else fmap (* prime ^ (v `quot` 2)) (sqrtModPP u (prime, expo - v))
  | otherwise          = fmap fixup (sqrtModP n prime)
  where
    -- n reduced modulo prime^expo, split as prime^v * u (u coprime to prime)
    nr     = n `mod` (prime ^ expo)
    (v, u) = splitOff prime nr
    -- Lift a root r modulo prime (r nonzero mod prime here) to prime^expo.
    fixup r = let diff' = r*r - n
              in if diff' == 0
                   then r
                   else case splitOff prime diff' of
                          (e,q) | expo <= e -> r
                                | otherwise -> hoist (fromJust $ invertMod (2*r) prime) r (q `mod` prime) (prime^e)
    -- Hensel step. root^2 == n (mod pp), elim is ((root^2 - n) / pp) mod prime,
    -- and inv is the inverse of 2*root mod prime. Adding -elim*inv*pp
    -- raises the precision to at least prime*pp.
    hoist inv root elim pp
      | diff' == 0 = root'
      | expo <= ex = root'
      | otherwise  = hoist inv root' (nelim `mod` prime) (prime^ex)
      where
        root' = (root + (inv*(prime - elim))*pp) `mod` (prime*pp)
        diff' = root'*root' - n
        (ex, nelim) = splitOff prime diff'
-- | Square root of @n@ modulo @2^e@, or 'Nothing' if none exists.
--
-- For @e < 2@ the answer is @n mod 2@. Otherwise @n mod 2^e@ is written as
-- @2^k * s@ with @s@ odd. A root exists only if @k@ is even (or @k >= e@).
-- It is then @2^(k/2)@ times a root of @s@ modulo @2^(e-k)@, reduced
-- modulo @2^e@.
sqM2P :: Integer -> Int -> Maybe Integer
sqM2P n e
  | e < 2     = Just (n `mod` 2)
  | n' == 0   = Just 0
  | e <= k    = Just 0
  | odd k     = Nothing
  | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
  where
    mdl = 1 `shiftL` e
    n' = n `mod` mdl
    (k, s) = shiftToOddCount n'
    k2 = k `quot` 2
    e2 = e - k
    -- a root of the odd number r modulo 2^p, if any
    solve _ 1 = Just 1
    solve 1 _ = Just 1
    solve r p
      | rem4 r == 3 = Nothing   -- odd squares are 1 (mod 4)
      | p == 2      = Just 1
      | rem8 r == 5 = Nothing   -- odd squares are 1 (mod 8)
      -- start from x = r: x^2 - r = r*(r-1) has 2-valuation v2(r-1)
      | otherwise   = fixup r (fst $ shiftToOddCount (r - 1))
      where
        -- x^2 == r (mod 2^pw); adding 2^(pw-1) raises the valuation of x^2 - r
        fixup x pw
          | pw >= e2  = Just x
          | otherwise = fixup x' pw'
          where
            x'  = x + (1 `shiftL` (pw - 1))
            d   = x'*x' - r
            pw' = if d == 0 then e2 else fst (shiftToOddCount d)
-- | A square root of @n@ modulo the product of the prime powers @pps@.
--
-- A root is computed for each prime power with 'sqrtModPP', and the roots
-- are glued together with 'chineseRemainder'. If any prime power admits no
-- root, the result is 'Nothing'.
sqrtModF :: Integer -> [(Integer,Int)] -> Maybe Integer
sqrtModF n pps =
  mapM (sqrtModPP n) pps >>= \roots ->
    chineseRemainder (zip roots [p ^ e | (p, e) <- pps])
-- | Square roots of @n@ modulo the product of the prime powers @pps@.
--
-- Every combination of the per-prime-power roots from 'sqrtModPPList' is
-- merged via the Chinese remainder theorem. An empty factorisation denotes
-- the modulus 1, whose only residue is 0. This matches 'sqrtModF', which
-- yields @Just 0@ for @[]@; the previous code crashed in 'foldl1' here.
sqrtModFList :: Integer -> [(Integer,Int)] -> [Integer]
sqrtModFList _ [] = [0]
sqrtModFList n pps = map fst $ foldl1 (liftM2 comb) cs
  where
    ms :: [Integer]
    ms = map (uncurry (^)) pps
    rs :: [[Integer]]
    rs = map (sqrtModPPList n) pps
    -- tag every root with the modulus it belongs to
    cs :: [[(Integer,Integer)]]
    cs = zipWith (\l m -> map (\x -> (x,m)) l) rs ms
    -- merge two (root, modulus) pairs into one modulo the product
    comb t1@(_,m1) t2@(_,m2) = (chineseRemainder2 t1 t2, m1*m2)
-- | Square roots of @n@ modulo @prime^expo@, all reduced into
-- @[0, prime^expo)@.
--
-- For @prime == 2@ and @expo >= 2@, the root @r@ from 'sqM2P' yields the
-- candidates @r@, @r + m@, @m - r@ and @-r@ (with @m = 2^(expo-1)@), with
-- duplicates removed. For @expo <= 1@ only @r@ itself is returned, since
-- modulo 2 the root is unique.
--
-- For odd primes the result is @[prime^expo - r, r]@, or @[0]@.
--
-- NOTE(review): when @prime@ divides @n@ this list is not necessarily
-- complete.
sqrtModPPList :: Integer -> (Integer,Int) -> [Integer]
sqrtModPPList n (2,expo)
  = case sqM2P n expo of
      Just r
        | expo <= 1 -> [r]
        | otherwise ->
            let m   = 1 `shiftL` (expo - 1)
                mdl = 2 * m
            in nub [r, (r + m) `mod` mdl, (m - r) `mod` mdl, (mdl - r) `mod` mdl]
      Nothing -> []
sqrtModPPList n pe@(prime,expo)
  = case sqrtModPP n pe of
      Just 0  -> [0]
      Just r  -> [prime^expo - r, r]
      Nothing -> []
-- | Solve the system @x == r_i (mod m_i)@ for the given @(r_i, m_i)@ pairs.
--
-- Each term is @r * inv * cofactor@, where the cofactor is the product of
-- the other moduli and @inv@ is its inverse modulo @m@. If some cofactor is
-- not invertible (moduli not pairwise coprime), the fold short-circuits to
-- 'Nothing'. An empty list gives @Just 0@.
chineseRemainder :: [(Integer,Integer)] -> Maybe Integer
chineseRemainder remainders = foldM accumulate 0 remainders
  where
    !modulus = product [m | (_, m) <- remainders]
    accumulate acc (r, m) = do
        inv <- invertMod cofactor m
        return $! (acc + inv * cofactor * r) `mod` modulus
      where
        cofactor = modulus `quot` m
-- | Combine @x == r1 (mod md1)@ and @x == r2 (mod md2)@ into one residue
-- modulo @md1 * md2@.
--
-- Presumably @extendedGCD md1 md2@ returns @(g, u, v)@ with
-- @u*md1 + v*md2 == g@ (TODO confirm). With coprime moduli (@g == 1@),
-- @1 - u*md1@ is 1 mod @md1@ and 0 mod @md2@, and symmetrically for
-- @1 - v*md2@. Coprimality is not checked.
chineseRemainder2 :: (Integer,Integer) -> (Integer,Integer) -> Integer
chineseRemainder2 (r1, md1) (r2, md2)
  = case extendedGCD md1 md2 of
      (_, u, v) -> ((1 - u*md1)*r1 + (1 - v*md2)*r2) `mod` (md1*md2)
-- | Parity test via the lowest bit of the value truncated to 'Int'.
-- For signed types this works for negative values too (two's complement).
evenI :: (Integral a, Bits a) => a -> Bool
evenI n = not (testBit (fromIntegral n :: Int) 0)
-- | Residue modulo 4 in @[0, 3]@, taken from the two lowest bits.
-- Negative inputs therefore map to their non-negative residue.
rem4 :: (Integral a, Bits a) => a -> Int
rem4 n = (fromIntegral n :: Int) .&. 0x3
-- | Residue modulo 8 in @[0, 7]@, taken from the three lowest bits.
-- Negative inputs therefore map to their non-negative residue.
rem8 :: (Integral a, Bits a) => a -> Int
rem8 n = (fromIntegral n :: Int) .&. 0x7
-- | The Jacobi symbol (2/b), indexed by @b mod 8@.
--
-- The value is +1 for b == 1 or 7 (mod 8) and -1 for b == 3 or 5 (mod 8).
-- Even indices hold 0. The previous table had +1 at indices 3 and 5, which
-- made every (2/b) factor in the Jacobi code equal to 1.
jac2 :: UArray Int Int
jac2 = array (0,7) [(0,0),(1,1),(2,0),(3,-1),(4,0),(5,-1),(6,0),(7,1)]
-- | A quadratic non-residue modulo the odd prime @n@, as needed by
-- Tonelli-Shanks.
--
-- 2 is a non-residue exactly when n == 3 or 5 (mod 8). Otherwise primes
-- are tried in order until one has Jacobi symbol -1. Searching for +1, as
-- before, returned a residue and broke the square-root loop.
findNonSquare :: Integer -> Integer
findNonSquare n
  | rem8 n == 5 || rem8 n == 3 = 2
  | otherwise = search primelist
  where
    -- NOTE(review): the n-dependent offset keeps the start within 68..71, so
    -- no prime is skipped. Presumably it only stops the sieve being shared
    -- across calls; confirm against sieveFrom.
    primelist = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
                  ++ sieveFrom (68 + n `rem` 4)
    search (p:ps)
      | jacobi' p n == -1 = p
      | otherwise         = search ps
    search _ = error "Should never have happened, prime list exhausted."