-- Diophantine equations and related functions: currently Cornacchia's
-- algorithm for solving x^2 + d*y^2 = m.

module Math.NumberTheory.Diophantine
  ( cornacchiaPrimitive
  , cornacchia
  )
where

import Data.List.Infinite (Infinite(..))
import qualified Data.List.Infinite as Inf

import           Math.NumberTheory.Moduli.Sqrt  ( sqrtsModFactorisation )
import           Math.NumberTheory.Primes       ( factorise
                                                , unPrime
                                                , UniqueFactorisation
                                                )
import           Math.NumberTheory.Roots        ( integerSquareRoot )
import           Math.NumberTheory.Utils.FromIntegral

-- | Internal worker behind 'cornacchiaPrimitive', implementing Cornacchia's
-- algorithm as described at
-- <https://en.wikipedia.org/wiki/Cornacchia%27s_algorithm>.
--
-- It takes each modular square root @r0@ of @m - d@ (i.e. of @-d@ modulo @m@)
-- produced by 'sqrtsModFactorisation', keeping only those with
-- @r0 <= div m 2@. It then walks the Euclidean remainder sequence
-- @m, r0, m mod r0, ...@ to the first remainder @r@ with @r * r < m@. It
-- yields @(r, s)@ exactly when @(m - r * r)@ is divisible by @d@ and the
-- quotient is a perfect square @s * s@.
--
-- No preconditions are checked here; 'cornacchiaPrimitive' validates
-- @1 <= d < m@ and @gcd d m == 1@ before calling this.
cornacchiaPrimitive' :: Integer -> Integer -> [(Integer, Integer)]
cornacchiaPrimitive' d m = concatMap (checkCandidate . firstSmall) halfRoots
 where
  -- Only one member of each {r, m - r} pair of roots is needed.
  halfRoots :: [Integer]
  halfRoots =
    [ r
    | r <- sqrtsModFactorisation (m - d) (factorise m)
    , r <= m `div` 2
    ]

  -- Remainders of the Euclidean algorithm started from (a, b): a, b, a mod b, ...
  remainders :: Integer -> Integer -> Infinite Integer
  remainders a b = a :< remainders b (a `mod` b)

  -- First remainder in the sequence for (m, r0) whose square is below m.
  firstSmall :: Integer -> Integer
  firstSmall r0 =
    Inf.head (Inf.dropWhile (\r -> r * r >= m) (remainders m r0))

  -- (r, s) is a solution when (m - r*r) / d is an exact integer square s*s.
  checkCandidate :: Integer -> [(Integer, Integer)]
  checkCandidate r = case (m - r * r) `divMod` d of
    (s2, 0) | let s = integerSquareRoot s2, s * s == s2 -> [(r, s)]
    _ -> []

-- | Finds all primitive solutions @(x, y)@ to the Diophantine equation
--
-- > x^2 + d*y^2 = m
--
-- where @1 <= d < m@ and @gcd d m == 1@. It calls 'error' if either
-- precondition fails. When @m@ is square-free, these are all of the positive
-- integer solutions.
cornacchiaPrimitive :: Integer -> Integer -> [(Integer, Integer)]
cornacchiaPrimitive d m
  | d < 1 || d >= m = error "precondition failed: 1 <= d < m"
  | gcd d m /= 1    = error "precondition failed: d and m coprime"
  | d /= 1          = found
  -- For d == 1 the equation is symmetric in x and y, but the algorithm only
  -- reports one ordering of each pair, so the swapped pair is added here.
  | otherwise       = concatMap withMirror found
 where
  found = cornacchiaPrimitive' d m
  withMirror (x, y)
    | x == y    = [(x, y)]
    | otherwise = [(x, y), (y, x)]

-- Numbers whose square divides the input. For n = p1^k1 * ... * pr^kr these
-- are all the products p1^j1 * ... * pr^jr with 0 <= ji <= ki `div` 2,
-- starting from 1.
squareFactors :: UniqueFactorisation a => a -> [a]
squareFactors n = foldl extend [1] (factorise n)
 where
  -- Multiply every value built so far by each admissible power of p,
  -- keeping the earlier values as the outer loop.
  extend built (p, k) =
    (*) <$> built <*> map (unPrime p ^) [0 .. wordToInt k `div` 2]

-- | Finds all positive integer solutions @(x, y)@ to the Diophantine equation
--
-- > x^2 + d*y^2 = m
--
-- where @1 <= d < m@ and @gcd d m == 1@. It calls 'error' if either
-- precondition fails.
--
-- Every solution is a primitive solution for @m@ divided by @g^2@, scaled by
-- @g@, for some @g@ whose square divides @m@. The primitive solver is
-- therefore run once for each such @g@.
cornacchia :: Integer -> Integer -> [(Integer, Integer)]
cornacchia d m
  | d < 1 || d >= m = error "precondition failed: 1 <= d < m"
  | gcd d m /= 1    = error "precondition failed: d and m coprime"
  | otherwise       =
      [ (x * g, y * g)
      | g <- squareFactors m
      , let m' = m `div` (g * g)
      -- If m' <= d there is no solution with x, y >= 1, and calling
      -- cornacchiaPrimitive would break its d < m' precondition.
      , m' > d
      , (x, y) <- cornacchiaPrimitive d m'
      ]