-- |
-- Module:      Math.NumberTheory.Moduli.Equations
-- Copyright:   (c) 2018 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Polynomial modular equations.
--

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns        #-}

module Math.NumberTheory.Moduli.Equations
  ( solveLinear
  , solveQuadratic
  ) where

import Data.Constraint
import Data.List (foldl')
import Data.Maybe
import Data.Mod
import GHC.Integer.GMP.Internals
import GHC.TypeNats (KnownNat, natVal)

import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Singleton
import Math.NumberTheory.Moduli.Sqrt
import Math.NumberTheory.Primes
import Math.NumberTheory.Utils (recipMod)

-------------------------------------------------------------------------------
-- Linear equations

-- | Find all solutions of ax + b ≡ 0 (mod m).
--
-- The modulus is taken from the type of the arguments, the congruence is
-- solved over 'Integer' and every root is converted back with 'fromInteger'.
--
-- >>> :set -XDataKinds
-- >>> solveLinear (6 :: Mod 10) 4 -- solving 6x + 4 ≡ 0 (mod 10)
-- [(1 `modulo` 10),(6 `modulo` 10)]
solveLinear
  :: KnownNat m
  => Mod m   -- ^ a
  -> Mod m   -- ^ b
  -> [Mod m] -- ^ list of x
solveLinear a b = map fromInteger (solveLinear' modulus a' b')
  where
    -- 'natVal' is applied to @a@ only as a proxy carrying the type-level
    -- modulus @m@.
    modulus :: Integer
    modulus = toInteger (natVal a)

    -- Coefficients as plain integers.
    a' :: Integer
    a' = toInteger (unMod a)

    b' :: Integer
    b' = toInteger (unMod b)

-- | Solve @a * x + b ≡ 0 (mod m)@ over plain integers.
--
-- Arguments, in order: the modulus @m@, then the coefficients @a@ and @b@.
--
-- All three values are divided by their common divisor @d@. If the reduced
-- congruence modulo @m' = m / d@ has a root @x@ (as reported by
-- 'solveLinearCoprime'), the result is the @d@ values @x + m' * i@ for
-- @i@ in @[0 .. d - 1]@; otherwise the result is empty.
solveLinear' :: Integer -> Integer -> Integer -> [Integer]
solveLinear' m a b = case solveLinearCoprime m' (a `quot` d) (b `quot` d) of
  Nothing -> []
  Just x  -> map (\i -> x + m' * i) [0 .. d - 1]
  where
    -- Common divisor of the modulus and both coefficients.
    d :: Integer
    d = m `gcd` a `gcd` b

    -- Reduced modulus.
    m' :: Integer
    m' = m `quot` d

-- | Find one root of @a * x + b ≡ 0 (mod m)@.
--
-- Arguments, in order: the modulus @m@, then the coefficients @a@ and @b@.
--
-- Modulo 1 the answer is always @Just 0@. Otherwise the root is @-b@ times
-- the inverse of @a@ supplied by 'recipMod', reduced modulo @m@; the result
-- is 'Nothing' exactly when 'recipMod' returns 'Nothing'.
solveLinearCoprime :: Integer -> Integer -> Integer -> Maybe Integer
solveLinearCoprime 1 _ _ = Just 0
solveLinearCoprime m a b = toRoot <$> recipMod a m
  where
    -- Multiply first, then reduce: the whole product is taken modulo @m@.
    toRoot :: Integer -> Integer
    toRoot aInv = (negate b * aInv) `mod` m

-------------------------------------------------------------------------------
-- Quadratic equations

-- | Find all solutions of ax² + bx + c ≡ 0 (mod m).
--
-- The factorisation of the modulus is supplied as the first argument. The
-- equation is solved modulo every prime power of that factorisation with
-- 'solveQuadraticPrimePower', and the partial roots are glued together
-- pairwise with 'chinese'.
--
-- >>> :set -XDataKinds
-- >>> solveQuadratic sfactors (1 :: Mod 32) 0 (-17) -- solving x² - 17 ≡ 0 (mod 32)
-- [(9 `modulo` 32),(25 `modulo` 32),(7 `modulo` 32),(23 `modulo` 32)]
solveQuadratic
  :: SFactors Integer m
  -> Mod m   -- ^ a
  -> Mod m   -- ^ b
  -> Mod m   -- ^ c
  -> [Mod m] -- ^ list of x
solveQuadratic sm a b c = case proofFromSFactors sm of
  -- Matching on the proof brings @KnownNat m@ into scope, which
  -- 'fromInteger' at type @Mod m@ requires.
  Sub Dict ->
    map fromInteger
    $ fst
    $ combine
    $ map (\(p, n) -> (solveQuadraticPrimePower a' b' c' p n, unPrime p ^ n))
    $ unSFactors sm
  where
    -- Coefficients as plain integers.
    a' :: Integer
    a' = toInteger $ unMod a

    b' :: Integer
    b' = toInteger $ unMod b

    c' :: Integer
    c' = toInteger $ unMod c

    -- Merge lists of roots, each paired with its modulus, into one list of
    -- roots paired with the product of all moduli. The seed @([0], 1)@ is the
    -- single residue modulo 1. A strict left fold is used so the accumulator
    -- is evaluated step by step while the original left-to-right order of
    -- the results is preserved.
    combine :: [([Integer], Integer)] -> ([Integer], Integer)
    combine = foldl' merge ([0], 1)

    -- Every root on the left is paired with every root on the right.
    -- NOTE(review): 'fromJust' presumably never fails here because the moduli
    -- are powers of distinct primes and hence coprime; confirm against the
    -- contract of 'chinese' and 'unSFactors'.
    merge :: ([Integer], Integer) -> ([Integer], Integer) -> ([Integer], Integer)
    merge (xs, xm) (ys, ym) =
      ( [ fst $ fromJust $ chinese (x, xm) (y, ym) | x <- xs, y <- ys ]
      , xm * ym
      )

-- | Find all roots of @a * x² + b * x + c@ modulo a prime power.
--
-- Arguments, in order: the coefficients @a@, @b@, @c@, the prime @p@ and the
-- exponent. Roots modulo @p@ come from 'solveQuadraticPrime' and are lifted
-- one exponent at a time.
solveQuadraticPrimePower
  :: Integer
  -> Integer
  -> Integer
  -> Prime Integer
  -> Word
  -> [Integer]
solveQuadraticPrimePower a b c p = go
  where
    -- Roots modulo @p' ^ k@.
    go :: Word -> [Integer]
    go 0 = [0]                          -- the only residue modulo 1
    go 1 = solveQuadraticPrime a b c p
    go k = concatMap (liftRoot k) (go (k - 1))

    -- Hensel lifting
    -- https://en.wikipedia.org/wiki/Hensel%27s_lemma#Hensel_lifting
    --
    -- Takes a root @r@ produced for exponent @k - 1@ and returns the roots
    -- modulo @p' ^ k@ derived from it.
    liftRoot :: Word -> Integer -> [Integer]
    liftRoot k r = case recipMod (2 * a * r + b) pk of
      -- The derivative @2 * a * r + b@ has no inverse modulo @pk@: either
      -- @r@ already is a root modulo @pk@, in which case all @p'@ candidates
      -- @r + (pk / p') * i@ are returned, or nothing is returned.
      Nothing -> case fr of
        0 -> map (\i -> r + pk `quot` p' * i) [0 .. p' - 1]
        _ -> []
      -- Invertible derivative: a single corrected root.
      Just invDeriv -> [(r - fr * invDeriv) `mod` pk]
      where
        -- Target modulus.
        pk :: Integer
        pk = p' ^ k

        -- Value of the polynomial at @r@, reduced modulo @pk@.
        fr :: Integer
        fr = (a * r * r + b * r + c) `mod` pk

    -- The prime as a plain integer.
    p' :: Integer
    p' = unPrime p

-- | Find all roots of @a * x² + b * x + c@ modulo a prime.
--
-- Arguments, in order: the coefficients @a@, @b@, @c@ and the prime.
--
-- * For the prime 2 the two residues are checked by parity: 0 is a root when
--   @c@ is even, and 1 is a root when @a + b@ and @c@ have the same parity
--   (so that @a + b + c@ is even).
-- * For any other prime, if @a@ is divisible by it the equation is handed to
--   'solveLinear''; otherwise every square root @n@ of the discriminant
--   @b² - 4ac@ given by 'sqrtsModPrime' yields the root
--   @(n - b) * (2a)⁻¹@ modulo the prime.
solveQuadraticPrime
  :: Integer
  -> Integer
  -> Integer
  -> Prime Integer
  -> [Integer]
solveQuadraticPrime a b c (unPrime -> 2 :: Integer)
  = case (even c, even (a + b)) of
    (True, True) -> [0, 1]
    (True, _)    -> [0]
    (_, False)   -> [1]
    _            -> []
solveQuadraticPrime a b c p
  | a `rem` p' == 0
  = solveLinear' p' b c
  | otherwise
  = map (\n -> (n - b) * recipModInteger (2 * a) p' `mod` p')
  $ sqrtsModPrime (b * b - 4 * a * c) p
    where
      -- The prime as a plain integer.
      p' :: Integer
      p' = unPrime p