-- |
-- Module:      Math.NumberTheory.Moduli
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Daniel Fischer <daniel.is.fischer@googlemail.com>
-- Stability:   Provisional
-- Portability: Non-portable (GHC extensions)
--
-- Miscellaneous functions related to modular arithmetic.
--
{-# LANGUAGE CPP, BangPatterns #-}
module Math.NumberTheory.Moduli
    ( -- * Functions with input check
      jacobi
    , invertMod
    , powerMod
    , powerModInteger
    , chineseRemainder
      -- ** Partially checked input
    , sqrtModP
      -- * Unchecked functions
    , jacobi'
    , powerMod'
    , powerModInteger'
    , sqrtModPList
    , sqrtModP'
    , tonelliShanks
    , sqrtModPP
    , sqrtModPPList
    , sqrtModF
    , sqrtModFList
    , chineseRemainder2
    ) where

#include "MachDeps.h"

import Data.Word
import Data.Bits
import Data.Array.Unboxed
import Data.Array.Base (unsafeAt)
import Data.Maybe (fromJust)
import Data.List (nub)
import Control.Monad (foldM, liftM2)

import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
import Math.NumberTheory.GCD (extendedGCD)
import Math.NumberTheory.Primes.Heap (sieveFrom)
-- Guesstimated startup time for the Heap algorithm is lower than
-- the cost to sieve an entire chunk.

-- | Modular inverse.
--
--   @invertMod number modulus@ yields @Just inverse@ with
--
-- >    (number * inverse) `mod` (abs modulus) == 1
-- >    0 <= inverse < abs modulus
--
--   whenever such an @inverse@ exists, and @Nothing@ otherwise. That is
--   the case if @gcd number modulus > 1@, and also for @abs modulus == 1@,
--   where no residue is equal to @1@.
--
--   The modulus @0@ is special: only @1@ and @-1@ are invertible there,
--   and each of them is returned unchanged as its own inverse.
invertMod :: Integer -> Integer -> Maybe Integer
invertMod k 0
  | abs k == 1 = Just k
  | otherwise  = Nothing
invertMod k m
  | (candidate * num) `rem` modulus == 1 = Just candidate
  | otherwise                            = Nothing
  where
    modulus = abs m
    -- Least non-negative residue of k modulo the (positive) modulus.
    num = case k `rem` modulus of
            r | r < 0     -> r + modulus
              | otherwise -> r
    candidate = convergents False 1 0 modulus num
    -- Walk through the continued fraction expansion of modulus/num, keeping
    -- the numerators of the last two convergents (cur, prev) and the
    -- fraction numer/denom that is still to be expanded. The flag records
    -- the parity of the number of steps taken. Once the expansion is
    -- finished, the numerator of the penultimate convergent is, up to a
    -- sign determined by that parity, the inverse (if there is one at
    -- all, which the caller verifies by multiplying back).
    convergents !evenStep _ prev _ 0
      | evenStep  = prev
      | otherwise = modulus - prev
    convergents evenStep !cur prev numer denom =
      case numer `quotRem` denom of
        (q, r) -> convergents (not evenStep) (q * cur + prev) cur denom r

-- | Jacobi symbol @(a/b)@ with validation of the \"denominator\" @b@.
--
--   @b@ has to be odd and positive; a negative or an even @b@
--   (including @0@) raises an 'error'.
--
--   The result is @0@ if the two numbers have a common prime factor,
--   and @1@ or @-1@ otherwise.
{-# SPECIALISE jacobi :: Integer -> Integer -> Int,
                         Int -> Int -> Int,
                         Word -> Word -> Int
  #-}
jacobi :: (Integral a, Bits a) => a -> a -> Int
jacobi a b
  | b < 0            = error "Math.NumberTheory.Moduli.jacobi: negative denominator"
  | evenI b          = error "Math.NumberTheory.Moduli.jacobi: even denominator"
  | b == 1 || a == 1 = 1              -- trivial symbols
  | a == 0           = 0              -- b > 1 here
  | otherwise        = jacobi' a b    -- b odd and > 1, a neither 0 nor 1

-- Invariant expected from callers: b > 1 and odd
-- | Jacobi symbol @(a/b)@ computed without any validity check of the
--   \"denominator\" @b@, which is expected to be odd and greater than @1@.
{-# SPECIALISE jacobi' :: Integer -> Integer -> Int,
                          Int -> Int -> Int,
                          Word -> Word -> Int
  #-}
jacobi' :: (Integral a, Bits a) => a -> a -> Int
jacobi' a b
  | a == 0    = 0
  | a == 1    = 1
  | a < 0     = negSign * jacobi' (fromInteger oddAbs) b
  | a >= b    = case a `rem` b of
                  0 -> 0
                  r -> jacPS 1 r b
  | evenI a   = jacOL evenSign b oddPart
  | otherwise = jacOL (if rem4 a .&. rem4 b == 3 then (-1) else 1) b a
  where
    -- Whether the table entry for b modulo 8 says that (2/b) is 1.
    twoIsSquare = unsafeAt jac2 (rem8 b) == 1

    -- Negative numerator: split |a| = 2^twosAbs * oddAbs. The absolute
    -- value is taken on an Integer so that it cannot overflow at minBound.
    (twosAbs, oddAbs) = shiftToOddCount (abs (toInteger a))
    -- (-1/b): 1 if b is 1 modulo 4, -1 otherwise.
    minusOne | rem4 b == 1 = 1
             | otherwise   = -1
    -- (-1/b) multiplied by (2/b)^twosAbs.
    negSign | evenI twosAbs || twoIsSquare = minusOne
            | otherwise                    = negate minusOne

    -- Even numerator with 0 < a < b: split a = 2^twos * oddPart.
    (twos, oddPart) = shiftToOddCount a
    -- Sign from swapping oddPart and b: the masked value is 1 or 3,
    -- which gives 1 or -1 respectively.
    swapSign = 2 - (rem4 oddPart .&. rem4 b)
    evenSign | evenI twos || twoIsSquare = swapSign
             | otherwise                 = negate swapSign
-- Worker for a numerator that is positive and smaller than the
-- denominator b. The first argument is the sign accumulated so far.
-- Powers of two are stripped from the numerator, then numerator and
-- denominator are swapped (adjusting the sign) and 'jacOL' takes over.
{-# SPECIALISE jacPS :: Int -> Integer -> Integer -> Int,
                        Int -> Int -> Int -> Int,
                        Int -> Word -> Word -> Int
  #-}
jacPS :: (Integral a, Bits a) => Int -> a -> a -> Int
jacPS !j a b
  | evenI a   =
      let (z, o)     = shiftToOddCount a
          -- the factor 2^z contributes the sign 1
          twoKeeps   = evenI z || unsafeAt jac2 (rem8 b) == 1
          -- swapping o and b flips the sign
          swapFlips  = rem4 o .&. rem4 b == 3
      in jacOL (if twoKeeps /= swapFlips then j else negate j) b o
  | otherwise = jacOL (if rem4 a .&. rem4 b == 3 then negate j else j) b a

-- Worker for an odd, positive numerator that is larger than the
-- denominator. The first argument is the sign accumulated so far.
-- The numerator is reduced modulo the denominator and the rest of the
-- work is handed to 'jacPS'.
{-# SPECIALISE jacOL :: Int -> Integer -> Integer -> Int,
                        Int -> Int -> Int -> Int,
                        Int -> Word -> Word -> Int
  #-}
jacOL :: (Integral a, Bits a) => Int -> a -> a -> Int
jacOL !j a b
  | b == 1    = j
  | r == 0    = 0          -- common factor found
  | otherwise = jacPS j r b
  where
    r = a `rem` b

-- | Modular power.
--
-- > powerMod base exponent modulus
--
--   calculates @(base ^ exponent) \`mod\` modulus@ by repeated squaring and reduction.
--   If @exponent < 0@ and @base@ is invertible modulo @modulus@, @(inverse ^ |exponent|) \`mod\` modulus@
--   is calculated; if @base@ is not invertible, an 'error' is raised.
--   This function does some input checking and sanitation before calling the unsafe worker.
--
--   Special cases: for @modulus == 0@ the plain power @base ^ exponent@ is
--   returned, for @abs modulus == 1@ the result is @0@, and otherwise an
--   exponent of @0@ gives @1@.
{-# SPECIALISE powerMod :: Integer -> Int -> Integer -> Integer,
                           Integer -> Word -> Integer -> Integer
  #-}
{-# RULES
"powerMod/Integer" powerMod = powerModInteger
  #-}
powerMod :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod base expo md
  | md == 0     = base ^ expo
  | md' == 1    = 0
  | expo == 0   = 1
  | bse' == 1   = 1
  | expo < 0    = case invertMod bse' md' of
                    -- The exponent is negated as an Integer: for fixed-width
                    -- types @negate minBound@ wraps around to @minBound@
                    -- again, and the worker would then never see the
                    -- exponent become 1.
                    Just i  -> powerModInteger' i (negate (toInteger expo)) md'
                    Nothing -> error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus"
  | bse' == 0   = 0
  | otherwise   = powerMod' bse' expo md'
    where
      md' = abs md
      -- base reduced into the range [0, md'); only evaluated for md /= 0
      bse' = if base < 0 || md' <= base then base `mod` md' else base

-- | Worker for 'powerMod' that performs no input checking at all.
--   Base, exponent and modulus are all assumed to be strictly positive,
--   and the modulus is assumed to be greater than 1.
{-# SPECIALISE powerMod' :: Integer -> Int -> Integer -> Integer,
                            Integer -> Word -> Integer -> Integer
  #-}
{-# RULES
"powerMod'/Integer" powerMod' = powerModInteger'
  #-}
powerMod' :: (Integral a, Bits a) => Integer -> a -> Integer -> Integer
powerMod' base expo md = loop expo 1 base
  where
    mulMod x y = (x * y) `rem` md
    -- Right-to-left binary exponentiation: acc collects the factors for
    -- the bits consumed so far, sq is base^(2^i) for the current bit i.
    -- The loop ends when only the top bit of the exponent is left.
    loop 1 !acc !sq = mulMod acc sq
    loop e acc sq
      | testBit e 0 = loop rest (mulMod acc sq) sqNext
      | otherwise   = loop rest acc sqNext
      where
        rest   = e `shiftR` 1
        sqNext = mulMod sq sq

-- | Variant of 'powerMod' with the exponent fixed to 'Integer'.
--   Shifting large 'Integer's is expensive, so the worker used here
--   reduces the number of shifts applied to the exponent. Use this
--   function directly when you cannot or do not want to rely on the
--   rewrite rule that replaces 'powerMod' at this type.
powerModInteger :: Integer -> Integer -> Integer -> Integer
powerModInteger base ex mdl
  | mdl == 0  = base ^ ex
  | m == 1    = 0
  | ex == 0   = 1
  | ex < 0    = maybe (error "Math.NumberTheory.Moduli.powerMod: Base isn't invertible with respect to modulus")
                      (\i -> powerModInteger' i (negate ex) m)
                      (invertMod b m)
  | b == 0    = 0
  | b == 1    = 1
  | otherwise = powerModInteger' b ex m
  where
    m = abs mdl
    -- base reduced into the range [0, m); only evaluated for mdl /= 0
    b | base < 0 || m <= base = base `mod` m
      | otherwise             = base

-- | Specialised worker without input checks. Makes the same assumptions
--   as the general version 'powerMod''.
--
--   The exponent is consumed in 64-bit chunks, least significant chunk
--   first, by right-to-left binary exponentiation. The exponent has to be
--   strictly positive: the last chunk is processed until only its top set
--   bit is left, and a zero chunk never reaches that state.
powerModInteger' :: Integer -> Integer -> Integer -> Integer
powerModInteger' base expo md = go w1 1 base e1
  where
    -- w1: the lowest 64 bits of the exponent, e1: the bits above those.
    w1 = fromInteger expo
    e1 = expo `shiftR` 64
#if WORD_SIZE_IN_BITS == 32
  -- Shifting large Integers is expensive, hence we reduce the
  -- number of shifts by processing in 64-bit chunks.
  -- On 32-bit systems, every testBit on a Word64 would be a C-call,
  -- thus it is faster to split each Word64 into the constituent 32-bit
  -- Words and process those separately.
  -- The code becomes ugly, unfortunately.
    -- Arguments of go: the current 64-bit chunk, the product accumulated
    -- so far, the current repeated square of the base, and the exponent
    -- bits above the current chunk. When no higher bits remain, the
    -- current chunk is the last one and is handled by end.
    go :: Word64 -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0  = end a s w
    go w a s e = inner1 a s 0
      where
        -- low and high 32-bit halves of the chunk
        wl :: Word
        !wl = fromIntegral w
        wh :: Word
        !wh = fromIntegral (w `shiftR` 32)
        -- all 32 bits of the low half, then all 32 bits of the high half
        inner1 !au !sq 32 = inner2 au sq 0
        inner1 au sq i
          | testBit wl i = inner1 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner1 au ((sq*sq) `rem` md) (i+1)
        -- after the high half, continue with the next 64-bit chunk
        inner2 !au !sq 32 = go (fromInteger e) au sq (e `shiftR` 64)
        inner2 au sq i
          | testBit wh i = inner2 ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise    = inner2 au ((sq*sq) `rem` md) (i+1)
    -- Last chunk: if its high half is zero, the remaining bits are all in
    -- the low half; otherwise the low half is processed completely before
    -- the high half is finished off.
    end !a !s w
      | wh == 0   = fin a s wl
      | otherwise = innerE a s 0
        where
          wl :: Word
          !wl = fromIntegral w
          wh :: Word
          !wh = fromIntegral (w `shiftR` 32)
          innerE !au !sq 32 = fin au sq wh
          innerE au sq i
            | testBit wl i = innerE ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
            | otherwise    = innerE au ((sq*sq) `rem` md) (i+1)
    -- Final word: stop as soon as only the top set bit is left, so that
    -- no superfluous squaring is performed after the last multiplication.
    fin :: Integer -> Integer -> Word -> Integer
    fin !a !s 1 = (a*s) `rem` md
    fin a s w
      | testBit w 0 = fin ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = fin a ((s*s) `rem` md) (w `shiftR` 1)

#else
  -- WORD_SIZE_IN_BITS == 64, otherwise things wouldn't compile anyway
  -- Shorter code since we need not split each 64-bit word.
    -- Arguments of go: the current 64-bit chunk, the product accumulated
    -- so far, the current repeated square of the base, and the exponent
    -- bits above the current chunk. When no higher bits remain, the
    -- current chunk is the last one and is handled by end.
    go :: Word -> Integer -> Integer -> Integer -> Integer
    go !w !a !s 0  = end a s w
    go w a s e = inner a s 0
      where
        -- all 64 bits of the chunk (leading zeros included, they only
        -- square), then continue with the next chunk
        inner !au !sq 64 = go (fromInteger e) au sq (e `shiftR` 64)
        inner au sq i
          | testBit w i = inner ((au*sq) `rem` md) ((sq*sq) `rem` md) (i+1)
          | otherwise   = inner au ((sq*sq) `rem` md) (i+1)
    -- Last chunk: stop as soon as only the top set bit is left, so that
    -- no superfluous squaring is performed after the last multiplication.
    end !a !s 1 = (a*s) `rem` md
    end a s w
      | testBit w 0 = end ((a*s) `rem` md) ((s*s) `rem` md) (w `shiftR` 1)
      | otherwise   = end a ((s*s) `rem` md) (w `shiftR` 1)

#endif

-- | @sqrtModP n prime@ calculates a modular square root of @n@ modulo @prime@
--   if that exists. The second argument /must/ be a (positive) prime, otherwise
--   the computation may not terminate and if it does, may yield a wrong result.
--   The precondition is /not/ checked.
--
--   If @prime@ is a prime and @n@ a quadratic residue modulo @prime@, the result
--   is @Just r@ where @r^2 ≡ n (mod prime)@, if @n@ is a quadratic nonresidue,
--   the result is @Nothing@. If @prime@ divides @n@, the result is @Just 0@.
sqrtModP :: Integer -> Integer -> Maybe Integer
sqrtModP n 2 = Just (n `mod` 2)
sqrtModP n prime = case jacobi' n prime of
                     0 -> Just 0
                     -- Delegate to sqrtModP', which needs only a single
                     -- modular exponentiation for primes of the form 4*k+3
                     -- and reserves tonelliShanks for the primes of the
                     -- form 4*k+1 that it is documented for.
                     1 -> Just (sqrtModP' (n `mod` prime) prime)
                     _ -> Nothing

-- | @sqrtModPList n prime@ lists every square root of @n@ modulo
--   @prime@: none, one (for the prime @2@ or when the root is @0@) or two.
--   @prime@ /must/ be a (positive) prime.
--   The precondition is /not/ checked.
sqrtModPList :: Integer -> Integer -> [Integer]
sqrtModPList n 2     = [n `mod` 2]
sqrtModPList n prime = maybe [] bothSigns (sqrtModP n prime)
  where
    -- A nonzero root r comes together with its negative prime-r; since
    -- the group of units in Z/(p) is cyclic, there are no further roots.
    bothSigns 0 = [0]
    bothSigns r = [r, prime - r]

-- | @sqrtModP' square prime@ finds a square root of @square@ modulo
--   @prime@. @prime@ /must/ be a (positive) prime, and @square@ /must/ be a
--   quadratic residue modulo @prime@, i.e. @'jacobi' square prime == 1@.
--   The precondition is /not/ checked.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square 2 = square
sqrtModP' square prime =
    case rem4 prime of
      -- For primes of the form 4*k+3 a single exponentiation suffices.
      3 -> powerModInteger' square ((prime + 1) `quot` 4) prime
      _ -> tonelliShanks square prime

-- | @tonelliShanks square prime@ computes a square root of @square@
--   modulo @prime@ with the Tonelli-Shanks algorithm. @prime@ is expected
--   to be a prime of the form @4*k + 1@ and @square@ a quadratic residue
--   modulo @prime@.
--   None of this is checked.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = refine root0 check0 gen0 twoExp
  where
    -- prime - 1 = 2^twoExp * oddPart with oddPart odd
    (twoExp, oddPart) = shiftToOddCount (prime - 1)
    powP x e = powerModInteger' x e prime
    sqP x    = (x * x) `rem` prime
    mulP x y = (x * y) `rem` prime
    gen0   = powP (findNonSquare prime) oddPart
    root0  = powP square ((oddPart + 1) `quot` 2)
    check0 = powP square oddPart
    -- square x modulo prime k times
    squareTimes k x
        | k == 0    = x
        | otherwise = squareTimes (k - 1) (sqP x)
    -- number of squarings needed to turn x into 1, added to i
    orderLog i x
        | x == 1    = i
        | otherwise = orderLog (i + 1) (sqP x)
    -- r is the current root candidate and t the correction factor that
    -- still has to become 1; c and m are the current power of the
    -- non-residue and the exponent bound belonging to it.
    refine !r t c m
        | t == 1    = r
        | otherwise =
            let m' = orderLog 0 t
                b  = squareTimes (m - 1 - m') c
                c' = sqP b
            in refine (mulP r b) (mulP t c') c' m'

-- | @sqrtModPP n (prime,expo)@ calculates a square root of @n@
--   modulo @prime^expo@ if one exists. @prime@ /must/ be a
--   (positive) prime. @expo@ must be positive, @n@ must be coprime
--   to @prime@.
--
--   For odd primes, a root modulo @prime@ is computed with 'sqrtModP' and
--   then lifted step by step to a root modulo @prime^expo@. The prime @2@
--   is handled separately.
--
--   If @n@ is not coprime to @prime@ and the root cannot be lifted, an
--   'error' is raised.
sqrtModPP :: Integer -> (Integer,Int) -> Maybe Integer
sqrtModPP n (2,e) = sqM2P n e
sqrtModPP n (prime,expo) = sqrtModP n prime >>= fixup
  where
    fixup r
        -- The root is exact, hence valid modulo every power of the prime.
        -- Zero is divisible by every power of the prime, so there is no
        -- finite multiplicity that could be split off from it.
        | diff == 0 = Just r
        | expo <= e = Just r
        | otherwise = case invertMod (2*r) prime of
                        Just inv -> Just (hoist inv r (q `mod` prime) (prime^e))
                        Nothing  -> error "Math.NumberTheory.Moduli.sqrtModPP: argument not coprime to prime"
      where
        diff  = r*r - n
        -- r^2 - n = prime^e * q with q not divisible by prime
        (e,q) = splitOff prime diff
    -- inv is the inverse of 2*root modulo prime, root a square root of n
    -- modulo pp (a power of prime), and elim the residue modulo prime of
    -- (root^2 - n) / pp. The corrected root is valid modulo a higher power.
    hoist inv root elim pp
        | diff == 0     = root'    -- exact root, nothing left to correct
        | expo <= ex    = root'
        | otherwise     = hoist inv root' (nelim `mod` prime) (prime^ex)
          where
            root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)
            diff  = root'*root' - n
            (ex, nelim) = splitOff prime diff

-- Square root of n modulo 2^e, if there is one.
-- The even part of n (reduced modulo 2^e) is split off: an odd power of
-- two cannot be a square, an even power contributes half its exponent to
-- the root. For the odd part a root is built up bit by bit.
sqM2P :: Integer -> Int -> Maybe Integer
sqM2P n e
    | e < 2         = Just (n `mod` 2)
    | reduced == 0  = Just 0
    | e <= twos     = Just 0
    | odd twos      = Nothing
    | otherwise     = do
        root <- oddRoot oddPart oddBits
        return ((root `shiftL` halfTwos) `mod` modulus)
  where
    modulus  = 1 `shiftL` e
    reduced  = n `mod` modulus
    -- reduced = 2^twos * oddPart
    (twos, oddPart) = shiftToOddCount reduced
    halfTwos = twos `quot` 2
    -- number of bits for which the root of the odd part must be correct
    oddBits  = e - twos
    -- square root of the odd number r modulo 2^p
    oddRoot r p
        | p == 1 || r == 1  = Just 1
        | rem4 r == 3       = Nothing  -- from here on r is 1 modulo 4
        | p == 2            = Just 1   -- from here on p >= 3
        | rem8 r == 5       = Nothing  -- from here on r is 1 modulo 8
        | otherwise         = refine r (fst (shiftToOddCount (r - 1)))
      where
        -- x^2 - r is divisible by 2^known; setting one more bit of x
        -- raises that power until it reaches the required oddBits.
        refine x known
            | known >= oddBits  = Just x
            | otherwise         = refine bumped known'
          where
            bumped = x + (1 `shiftL` (known - 1))
            err    = bumped*bumped - r
            -- an exact root is good for every power of two
            known' | err == 0   = oddBits
                   | otherwise  = fst (shiftToOddCount err)

-- | @sqrtModF n primePowers@ calculates a square root of @n@ modulo
--   @product [p^k | (p,k) <- primePowers]@ if one exists and all primes
--   are distinct. A root is computed for every prime power separately
--   with 'sqrtModPP'; the roots are then combined with 'chineseRemainder'.
sqrtModF :: Integer -> [(Integer,Int)] -> Maybe Integer
sqrtModF n pps =
    mapM (sqrtModPP n) pps >>= \roots -> chineseRemainder (zip roots moduli)
  where
    moduli = [p ^ k | (p,k) <- pps]

-- | @sqrtModFList n primePowers@ calculates all square roots of @n@ modulo
--   @product [p^k | (p,k) <- primePowers]@ if all primes are distinct.
--
--   The roots modulo the individual prime powers are computed with
--   'sqrtModPPList' and every combination of them is merged with
--   'chineseRemainder2'. If @n@ has no root modulo one of the prime
--   powers, the result is empty.
--
--   For an empty list of prime powers the modulus is the empty product
--   @1@, and the result is @[0]@, in agreement with 'sqrtModF'.
sqrtModFList :: Integer -> [(Integer,Int)] -> [Integer]
sqrtModFList _ [] = [0]   -- foldl1 below needs at least one prime power
sqrtModFList n pps = map fst $ foldl1 (liftM2 comb) cs
  where
    -- the prime powers themselves
    ms :: [Integer]
    ms = map (uncurry (^)) pps
    -- all roots modulo each prime power
    rs :: [[Integer]]
    rs = map (sqrtModPPList n) pps
    -- every root paired with the modulus it belongs to
    cs :: [[(Integer,Integer)]]
    cs = zipWith (\l m -> map (\x -> (x,m)) l) rs ms
    -- merge two (root,modulus) pairs into one for the product modulus
    comb t1@(_,m1) t2@(_,m2) = (chineseRemainder2 t1 t2,m1*m2)

-- | @sqrtModPPList n (prime,expo)@ calculates the list of all
--   square roots of @n@ modulo @prime^expo@. The same restriction
--   as in 'sqrtModPP' applies to the arguments.
--
--   Modulo @2@ there is exactly one root. Modulo higher powers of @2@ the
--   candidates are the root @r@ found, its negative, and both of these
--   shifted by half the modulus; duplicates are removed. For odd primes a
--   nonzero root and its negative are returned.
sqrtModPPList :: Integer -> (Integer,Int) -> [Integer]
-- Modulo 2 every residue is its own unique square root. The general
-- formula below would also list the wrong residue here, because half the
-- modulus is odd.
sqrtModPPList n (2,1) = [n `mod` 2]
sqrtModPPList n (2,expo)
    = case sqM2P n expo of
        Just r -> let m  = 1 `shiftL` (expo-1)   -- half the modulus
                      md = 2*m
                  -- every candidate is reduced so that a root 0 is not
                  -- listed a second time as md
                  in nub [r, (r+m) `mod` md, (m-r) `mod` md, (md-r) `mod` md]
        _ -> []
sqrtModPPList n pe@(prime,expo)
    = case sqrtModPP n pe of
        Just 0 -> [0]
        Just r -> [prime^expo - r, r] -- The group of units in Z/(p^e) is cyclic
        _      -> []

-- | Given a list @[(r_1,m_1), ..., (r_n,m_n)]@ of @(residue,modulus)@
--   pairs, @chineseRemainder@ calculates the solution to the simultaneous
--   congruences
--
-- >
-- > r ≡ r_k (mod m_k)
-- >
--
--   if all moduli are pairwise coprime. If not all moduli are
--   pairwise coprime, the result is @Nothing@ regardless of whether
--   a solution exists.
--
--   The solution is reduced modulo the product of all moduli. Pairs with
--   modulus @1@ impose no condition and are skipped; for an empty list
--   the result is @Just 0@.
chineseRemainder :: [(Integer,Integer)] -> Maybe Integer
chineseRemainder remainders = foldM addRem 0 remainders
  where
    !modulus = product (map snd remainders)
    -- Every integer satisfies a congruence modulo 1. This case has to be
    -- skipped explicitly because 'invertMod' finds no inverse modulo 1,
    -- which would wrongly be taken for a non-coprime modulus.
    addRem acc (_,1) = Just acc
    addRem acc (r,m) = do
        -- cf is the product of all the other moduli; inv*cf is 1 modulo m
        -- and 0 modulo each of the others
        let cf = modulus `quot` m
        inv <- invertMod cf m
        Just $! (acc + inv*cf*r) `mod` modulus

-- | @chineseRemainder2 (r_1,m_1) (r_2,m_2)@ calculates the solution of
--
-- >
-- > r ≡ r_k (mod m_k)
--
--   if @m_1@ and @m_2@ are coprime. The result is reduced modulo
--   @m_1*m_2@. Coprimality is not checked.
chineseRemainder2 :: (Integer,Integer) -> (Integer,Integer) -> Integer
chineseRemainder2 (r1, md1) (r2,md2) = (unit1*r1 + unit2*r2) `mod` (md1*md2)
  where
    -- Bezout coefficients: u*md1 + v*md2 is the gcd, which is 1 for
    -- coprime moduli
    (_, u, v) = extendedGCD md1 md2
    unit1 = 1 - u*md1   -- 1 modulo md1 and 0 modulo md2
    unit2 = 1 - v*md2   -- 0 modulo md1 and 1 modulo md2

-- Utilities

-- Parity test. The argument is first converted to Int and only the
-- lowest bit of that is inspected; for large Integers this is much
-- faster than bit-fiddling on the Integer itself.
{-# SPECIALISE evenI :: Integer -> Bool,
                        Int -> Bool,
                        Word -> Bool
  #-}
evenI :: (Integral a, Bits a) => a -> Bool
evenI n = not (testBit (fromIntegral n :: Int) 0)

-- Residue modulo 4 in the range 0..3, taken from the two lowest bits of
-- the argument after conversion to Int.
{-# SPECIALISE rem4 :: Integer -> Int,
                       Int -> Int,
                       Word -> Int
  #-}
rem4 :: (Integral a, Bits a) => a -> Int
rem4 n = lowBits .&. 3
  where
    lowBits = fromIntegral n

-- Residue modulo 8 in the range 0..7, taken from the three lowest bits
-- of the argument after conversion to Int.
{-# SPECIALISE rem8 :: Integer -> Int,
                       Int -> Int,
                       Word -> Int
  #-}
rem8 :: (Integral a, Bits a) => a -> Int
rem8 n = lowBits .&. 7
  where
    lowBits = fromIntegral n

-- Jacobi symbol (2/b), indexed by b modulo 8: 1 for the residues 1 and 7,
-- -1 for the residues 3 and 5, and 0 for the even residues.
jac2 :: UArray Int Int
jac2 = listArray (0,7) [0, 1, 0, -1, 0, -1, 0, 1]

-- Find a number whose Jacobi symbol with respect to n is -1. The answer
-- is 2 if n is 3 or 5 modulo 8; otherwise the odd primes are tried in
-- increasing order.
findNonSquare :: Integer -> Integer
findNonSquare n
    | low3 == 3 || low3 == 5 = 2
    | otherwise =
        case filter isNonResidue candidates of
          (p:_) -> p
          []    -> error "Should never have happened, prime list exhausted."
  where
    low3 = rem8 n
    isNonResidue p = jacobi' p n == -1
    -- A short table of small primes, continued by a sieve. The start of
    -- the sieve depends on n to prevent sharing of the sieved list
    -- between calls.
    candidates = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
                    ++ sieveFrom (68 + n `rem` 4)