-- | Module for Diophantine equations and related functions.

module Math.NumberTheory.Diophantine
  ( cornacchiaPrimitive
  , cornacchia
  )
where

import           Math.NumberTheory.Moduli.Sqrt  ( sqrtsModFactorisation )
import           Math.NumberTheory.Primes       ( factorise
                                                , unPrime
                                                , UniqueFactorisation
                                                )
import           Math.NumberTheory.Roots        ( integerSquareRoot )
import           Math.NumberTheory.Utils.FromIntegral

-- | Internal implementation of Cornacchia's algorithm, as described at
-- <https://en.wikipedia.org/wiki/Cornacchia%27s_algorithm>.
-- See 'cornacchiaPrimitive' for the documented contract; this function does
-- not check the preconditions @1 <= d < m@ and @gcd d m == 1@ itself.
--
-- For each square root @r0@ of @-d@ modulo @m@ with @r0 <= m `div` 2@, it
-- runs the Euclidean algorithm on @(m, r0)@, takes the first remainder @r@
-- with @r * r < m@, and yields @(r, s)@ when @(m - r * r) / d@ is exactly
-- the perfect square @s * s@.
cornacchiaPrimitive' :: Integer -> Integer -> [(Integer, Integer)]
cornacchiaPrimitive' d m = concatMap solveFromRoot roots
 where
  -- Square roots of (m - d), i.e. of -d, modulo m. Keeping only the lower
  -- half picks one representative from each pair {r, m - r}.
  roots :: [Integer]
  roots = filter (<= m `div` 2) $ sqrtsModFactorisation (m - d) (factorise m)

  -- Pattern match instead of the partial 'head': if no remainder drops
  -- below sqrt m, this root simply contributes no solution.
  solveFromRoot :: Integer -> [(Integer, Integer)]
  solveFromRoot r0 = case dropWhile (\r -> r * r >= m) (gcdSeq m r0) of
    r : _ -> findSolution r
    []    -> []

  -- Euclidean remainder sequence a, b, a mod b, ...; it stops at the last
  -- non-zero term instead of going on to evaluate a `mod` 0.
  gcdSeq :: Integer -> Integer -> [Integer]
  gcdSeq a 0 = [a]
  gcdSeq a b = a : gcdSeq b (a `mod` b)

  -- If s = sqrt((m - r*r) / d) is an integer then (r, s) is a solution.
  findSolution :: Integer -> [(Integer, Integer)]
  findSolution r = [ (r, s) | rest == 0, s * s == s2 ]
   where
    (s2, rest) = (m - r * r) `divMod` d
    s          = integerSquareRoot s2

-- | Finds all primitive solutions @(x, y)@ of the Diophantine equation
--
-- > x^2 + d*y^2 = m
--
-- where @1 <= d < m@ and @gcd d m == 1@.
-- When @m@ is square-free these are all the positive integer solutions.
--
-- Calls 'error' if either precondition is violated.
cornacchiaPrimitive :: Integer -> Integer -> [(Integer, Integer)]
cornacchiaPrimitive d m
  | d < 1 || d >= m = error "precondition failed: 1 <= d < m"
  | gcd d m /= 1    = error "precondition failed: d and m coprime"
  | d /= 1          = primitive
  -- For d = 1 the equation is symmetric in x and y, but the algorithm does
  -- not generate both (x, y) and (y, x); add the mirror image unless x == y.
  | otherwise       = concatMap withMirror primitive
 where
  primitive :: [(Integer, Integer)]
  primitive = cornacchiaPrimitive' d m

  withMirror :: (Integer, Integer) -> [(Integer, Integer)]
  withMirror (x, y)
    | x == y    = [(x, y)]
    | otherwise = [(x, y), (y, x)]

-- Find numbers whose square is a factor of the input: every product that
-- takes, for each prime power p^e of the input, some p^k with 0 <= k <= e `div` 2.
squareFactors :: UniqueFactorisation a => a -> [a]
squareFactors n = map product (traverse halfPowers (factorise n))
 where
  -- p^0, p^1, ..., p^(e `div` 2) for one prime-power factor p^e
  halfPowers (p, e) = [ unPrime p ^ k | k <- [0 .. wordToInt e `div` 2] ]

-- | Finds all positive integer solutions @(x, y)@ of the Diophantine equation
--
-- > x^2 + d*y^2 = m
--
-- where @1 <= d < m@ and @gcd d m == 1@.
--
-- For every @sf@ whose square divides @m@, the primitive solutions of
-- @x^2 + d*y^2 = m `div` sf^2@ (from 'cornacchiaPrimitive') are scaled by @sf@.
--
-- Calls 'error' if either precondition is violated.
cornacchia :: Integer -> Integer -> [(Integer, Integer)]
cornacchia d m
  | d < 1 || d >= m = error "precondition failed: 1 <= d < m"
  | gcd d m /= 1    = error "precondition failed: d and m coprime"
  | otherwise       =
      [ (x * sf, y * sf)
      | sf <- squareFactors m
      , let m' = m `div` (sf * sf)
        -- cornacchiaPrimitive requires d < m'
      , m' > d
      , (x, y) <- cornacchiaPrimitive d m'
      ]