```-- |
-- Module:      Math.NumberTheory.GaussianIntegers
-- Licence:     MIT
-- Maintainer:  Chris Fredrickson <chris.p.fredrickson@gmail.com>
--
-- This module exports functions for manipulating Gaussian integers, including
-- computing their prime factorisations.
--

{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeFamilies  #-}

module Math.NumberTheory.GaussianIntegers
  ( GaussianInteger(..)
  , ι
  , conjugate
  , norm
  , primes
  , findPrime
  ) where

import Control.DeepSeq (NFData)
import Data.Coerce
import Data.List (mapAccumL, partition)
import Data.Maybe (fromMaybe)
import Data.Ord (comparing)
import GHC.Generics

import qualified Math.NumberTheory.Euclidean as ED
import Math.NumberTheory.Moduli.Sqrt
import Math.NumberTheory.Powers (integerSquareRoot)
import Math.NumberTheory.Primes.Types
import qualified Math.NumberTheory.Primes.Sieve as Sieve
import qualified Math.NumberTheory.Primes.Testing as Testing
import qualified Math.NumberTheory.Primes  as U
import Math.NumberTheory.Utils              (mergeBy)
import Math.NumberTheory.Utils.FromIntegral

infix 6 :+

-- | A Gaussian integer @a + b*ι@ whose components @a@ and @b@ are both
-- rational integers. Both components are strict fields.
--
-- The derived 'Ord' is lexicographic: it compares real parts first and then
-- imaginary parts. It is unrelated to magnitude.
data GaussianInteger = (:+)
  { real :: !Integer -- ^ Real part @a@.
  , imag :: !Integer -- ^ Imaginary part @b@.
  }
  deriving (Eq, Ord, Generic)

-- | Deep evaluation uses the 'Generic'-based default method of 'NFData'.
-- Both fields are strict 'Integer's.
instance NFData GaussianInteger where

-- | The imaginary unit @0 + 1*ι@, a square root of @-1@:
--
-- > ι * ι == -1
ι :: GaussianInteger
ι = (:+) 0 1

-- | Human-readable rendering.
--
-- * A zero imaginary part prints only the real part, so @0 :+ 0@ prints as @0@.
-- * A zero real part is omitted when the imaginary part is nonzero.
-- * A coefficient of magnitude 1 in front of @ι@ is elided.
--
-- Examples: @3@, @ι@, @-2*ι@, @3+ι@, @3-2*ι@. The output is not valid
-- Haskell source.
instance Show GaussianInteger where
  show (re :+ im) = case (re, im) of
      (_, 0) -> show re
      (0, _) -> leadingSign ++ imagPart
      _      -> show re ++ joiner ++ imagPart
    where
      positive    = im > 0
      magnitude   = abs im
      imagPart    = if magnitude == 1 then "ι" else show magnitude ++ "*ι"
      leadingSign = if positive then "" else "-"
      joiner      = if positive then "+" else "-"

-- | Ring operations of the Gaussian integers. 'abs' and 'signum' both come
-- from 'absSignum', so @abs z * signum z == z@ holds. 'abs' lands in the
-- first quadrant and 'signum' is one of the four units (or 0).
-- Subtraction uses the class default, @x - y = x + negate y@.
instance Num GaussianInteger where
  (x1 :+ y1) + (x2 :+ y2) = (x1 + x2) :+ (y1 + y2)
  (x1 :+ y1) * (x2 :+ y2) = (x1 * x2 - y1 * y2) :+ (x1 * y2 + y1 * x2)
  negate (x :+ y)         = negate x :+ negate y
  fromInteger n           = n :+ 0
  abs z                   = fst (absSignum z)
  signum z                = snd (absSignum z)

-- | Split @z@ into @(abs z, signum z)@ such that @abs z * signum z == z@.
--
-- The absolute value is rotated into the first quadrant, with real part > 0
-- and imaginary part >= 0, by dividing out a unit (1, ι, -1 or -ι). That unit
-- is the signum. Zero maps to @(0, 0)@.
absSignum :: GaussianInteger -> (GaussianInteger, GaussianInteger)
absSignum z@(a :+ b)
  | a == 0, b == 0 = (z, 0)                       -- origin
  | a > 0, b >= 0  = (z, 1)                       -- already in (0, inf) x [0, inf)i
  | b > 0          = (b :+ negate a, ι)           -- here a <= 0: (-inf, 0] x (0, inf)i
  | a < 0          = (negate a :+ negate b, -1)   -- here b <= 0: (-inf, 0) x (-inf, 0]i
  | otherwise      = (negate b :+ a, -ι)          -- remaining: [0, inf) x (-inf, 0)i

-- | Division via 'divHelper'. 'ED.divMod' rounds each component of the exact
-- quotient @g / h@ towards negative infinity. 'ED.quotRem' rounds each
-- component towards zero.
instance ED.Euclidean GaussianInteger where
  divMod  = divHelper div
  quotRem = divHelper quot

-- | Shared implementation of 'ED.quotRem' and 'ED.divMod' for Gaussian
-- integers.
--
-- The exact quotient is @g / h == g * conjugate h / norm h@. Only rational
-- division by @norm h@ is needed, applied separately to the real and the
-- imaginary part with the supplied rounding function. Returns
-- @(q, g - h * q)@. With 'quot' or 'div', a zero @h@ triggers a division by
-- zero once @q@ is forced.
--
-- NOTE(review): rounding each component independently only bounds the
-- remainder by @norm r < 2 * norm h@. For example, @divHelper quot (9 :+ 9) 10@
-- is @(0, 9 :+ 9)@, with norm 162 > 100. Confirm that no caller (e.g. a
-- Euclidean gcd) relies on @norm r < norm h@.
divHelper
  :: (Integer -> Integer -> Integer)
  -> GaussianInteger
  -> GaussianInteger
  -> (GaussianInteger, GaussianInteger)
divHelper divide g h = (q, g - h * q)
  where
    re :+ im = g * conjugate h
    n        = norm h
    q        = divide re n :+ divide im n

-- | Complex conjugate: negate the imaginary part, so that
-- @conjugate (a :+ b) == a :+ (-b)@.
conjugate :: GaussianInteger -> GaussianInteger
conjugate (re :+ im) = re :+ negate im

-- | The norm @a² + b²@ of @a :+ b@, i.e. the square of its magnitude. It is
-- multiplicative: @norm (z * w) == norm z * norm w@.
norm :: GaussianInteger -> Integer
norm (re :+ im) = re ^ (2 :: Int) + im ^ (2 :: Int)

-- | Decide whether a Gaussian integer is a Gaussian prime.
--
-- A nonzero number lying on an axis is prime iff its nonzero component has
-- absolute value congruent to 3 mod 4 and is a rational prime. Any other
-- number (including 0) is prime iff its norm is a rational prime.
isPrime :: GaussianInteger -> Bool
isPrime g@(x :+ y)
  | x == 0, y /= 0 = onAxis y
  | y == 0, x /= 0 = onAxis x
  | otherwise      = Testing.isPrime (norm g)
  where
    onAxis n = abs n `mod` 4 == 3 && Testing.isPrime n

-- | Infinite list of the Gaussian primes, one per associate class, in order of
-- ascending norm.
--
-- It is built from the rational primes:
--
-- * 2 contributes @1 :+ 1@.
-- * Each odd prime @p ≡ 3 (mod 4)@ stays prime as @p :+ 0@, with norm @p²@.
-- * Each @p ≡ 1 (mod 4)@ splits into @x :+ y@ (from 'findPrime') and
--   @y :+ x@, which is @ι * conjugate (x :+ y)@. Both have norm @p@.
primes :: [U.Prime GaussianInteger]
primes = coerce $ (1 :+ 1) : mergeBy (comparing norm) inert split
  where
    oddPrimes :: [Prime Integer]
    oddPrimes = tail Sieve.primes

    isInert p = unPrime p `mod` 4 == 3

    inert = [ unPrime p :+ 0 | p <- oddPrimes, isInert p ]

    split =
      [ g
      | p <- oddPrimes
      , not (isInert p)
      , let Prime (x :+ y) = findPrime p
      , g <- [x :+ y, y :+ x]
      ]

-- | Find a Gaussian integer whose norm is the given rational prime @p@, where
-- @p == 2@ or @p@ has the form @4k + 1@.
--
-- For odd @p@ this uses the
-- <http://www.ams.org/journals/mcom/1972-26-120/S0025-5718-1972-0314745-6/S0025-5718-1972-0314745-6.pdf Hermite-Serret algorithm>.
-- Starting from a square root @z@ of @-1@ modulo @p@, it runs the Euclidean
-- algorithm on @(p, z)@. It stops at the first remainder not exceeding
-- @integerSquareRoot p@; that remainder and the following one are the real
-- and imaginary parts.
--
-- @p == 2@ is answered directly with @1 :+ 1@. Starting the same walk from
-- @(2, 1)@ (1 is the only square root of -1 mod 2) would return the unit
-- @1 :+ 0@ wrapped as a 'Prime'.
--
-- Calls 'error' when 'sqrtsModPrime' finds no square root of @-1@ modulo @p@,
-- as happens for primes of the form @4k + 3@.
findPrime :: Prime Integer -> U.Prime GaussianInteger
findPrime p
  | unPrime p == 2
  = Prime (1 :+ 1)
  | otherwise
  = case sqrtsModPrime (-1) p of
      []    -> error "findPrime: an argument must be prime p = 4k + 1"
      z : _ -> Prime $ go (unPrime p) z -- Effectively we calculate gcdG' (p :+ 0) (z :+ 1)
  where
    sqrtp :: Integer
    sqrtp = integerSquareRoot (unPrime p)

    -- One Euclidean step per call; stop once the larger value is <= sqrt p.
    go :: Integer -> Integer -> GaussianInteger
    go g h
      | g <= sqrtp = g :+ h
      | otherwise  = go h (g `mod` h)

-- | Prime factorisation of a Gaussian integer. It is unique up to units
-- (@±1@, @±ι@), and unit factors are not included in the result.
--
-- The work is driven by the rational factorisation of @norm g@. Each rational
-- prime is consumed in turn, and the cofactor left over is threaded through
-- 'mapAccumL':
--
-- * @2@ with exponent @e@ gives @(1 :+ 1)^e@.
-- * @p ≡ 3 (mod 4)@ with exponent @e@ gives @(p :+ 0)^(e/2)@, since
--   @norm p == p²@.
-- * Any other @p@ splits between 'findPrime' @p@ and its conjugate, rotated
--   back into the first quadrant. Zero multiplicities are dropped.
--
-- Factors appear in the order in which 'U.factorise' lists the primes of the
-- norm. The input is expected to be nonzero.
factorise :: GaussianInteger -> [(Prime GaussianInteger, Word)]
factorise g = concat (snd (mapAccumL step g (U.factorise (norm g))))
  where
    step :: GaussianInteger -> (Prime Integer, Word) -> (GaussianInteger, [(Prime GaussianInteger, Word)])
    step z (p, e)
      | q == 2
      = (divideByTwo z, [(Prime (1 :+ 1), e)])
      | q `mod` 4 == 3
      = let half = e `quot` 2
        in (z `quotI` (q ^ half), [(Prime (q :+ 0), half)])
      | otherwise
      = let gp            = findPrime p
            gp'           = Prime (abs (conjugate (unPrime gp)))
            (k, k', rest) = divideByPrime gp q e z
        in (rest, filter ((> 0) . snd) [(gp, k), (gp', k')])
      where
        q = unPrime p

-- | Remove all factors of @1 :+ 1@ from the argument using only rational
-- arithmetic; the result is correct up to a unit.
--
-- * Both components even: @z@ is divisible by @2@, which is
--   @-ι * (1 :+ 1)^2@. Halve both components and repeat.
-- * Both components odd: divide once by @1 - ι@, an associate of @1 :+ 1@.
--   The two resulting components differ by an odd number, so no further
--   factor of @1 :+ 1@ remains.
-- * Mixed parity: @z@ is not divisible by @1 :+ 1@ and is returned unchanged.
--
-- Zero is returned unchanged. It is divisible by every power of @1 :+ 1@, and
-- the halving branch would otherwise recurse forever.
divideByTwo :: GaussianInteger -> GaussianInteger
divideByTwo z@(x :+ y)
  | x == 0, y == 0
  = z
  | even x, even y
  = divideByTwo $ z `quotI` 2
  | odd x, odd y
  = (x - y) `quot` 2 :+ (x + y) `quot` 2
  | otherwise
  = z

-- | Strip every factor of the Gaussian prime @p@ and of its conjugate from
-- @z@. Dividing by @p@ is done as multiplying by @conj p@ followed by exact
-- rational division by @np = p * conj p@, which avoids Gaussian division.
--
-- The caller passes the total number of such factors expected in @z@. The
-- function returns how many were @p@, how many were @conj p@, and the
-- remaining cofactor. If @z@ does not contain the expected number of factors,
-- forcing the cofactor calls 'error'.
divideByPrime
  :: Prime GaussianInteger -- ^ Gaussian prime p
  -> Integer               -- ^ Precomputed norm of p, of form 4k + 1
  -> Word                  -- ^ Expected number of factors (either p or conj p)
                           --   in Gaussian integer z
  -> GaussianInteger       -- ^ Gaussian integer z
  -> ( Word                -- Multiplicity of factor p in z
     , Word                -- Multiplicity of factor conj p in z
     , GaussianInteger     -- Remaining Gaussian integer
     )
divideByPrime p np k = loop k 0
  where
    pz :: GaussianInteger
    pz = unPrime p

    -- Phase 1: while at least two factors remain and z is divisible by the
    -- rational integer np, strip one p and one conj p together.
    -- Phase 2 (the final guard): count the factors of p, then treat every
    -- remaining expected factor as conj p.
    loop :: Word -> Word -> GaussianInteger -> (Word, Word, GaussianInteger)
    loop remaining both z
      | remaining == 0
      = (both, both, z)
      | remaining >= 2
      , Just z' <- z `quotEvenI` np
      = loop (remaining - 2) (both + 1) z'
      | otherwise
      = (both + onlyP, both + onlyConj, rest)
      where
        (onlyP, afterP) = stripP remaining 0 z
        onlyConj        = remaining - onlyP
        rest = case drop (wordToInt onlyConj) (iterate divConj afterP) of
          r : _ -> r
          []    -> err -- unreachable: 'iterate' yields an infinite list

    -- Divide by p (multiply by conj p, divide by np) while the division is
    -- exact, at most @budget@ times. Returns the count and what is left.
    stripP :: Word -> Word -> GaussianInteger -> (Word, GaussianInteger)
    stripP 0 found z = (found, z)
    stripP budget found z = case (z * conjugate pz) `quotEvenI` np of
      Just z' -> stripP (budget - 1) (found + 1) z'
      Nothing -> (found, z)

    -- Divide by conj p (multiply by p, divide by np); must be exact.
    divConj :: GaussianInteger -> GaussianInteger
    divConj g = fromMaybe err $ (g * pz) `quotEvenI` np

    err = error $ "divideByPrime: malformed arguments" ++ show (p, np, k)

-- | Divide both components by the rational integer @n@, truncating each
-- towards zero. The result is exact whenever @n@ divides both components.
quotI :: GaussianInteger -> Integer -> GaussianInteger
quotI (re :+ im) n = quot re n :+ quot im n

-- | Exact division by a rational integer. Returns 'Just' the quotient when
-- @n@ divides both components, and 'Nothing' otherwise. The imaginary part's
-- remainder is only inspected once the real part divides evenly.
quotEvenI :: GaussianInteger -> Integer -> Maybe GaussianInteger
quotEvenI (re :+ im) n = case (re `quotRem` n, im `quotRem` n) of
  ((qRe, 0), (qIm, 0)) -> Just (qRe :+ qIm)
  _                    -> Nothing

-------------------------------------------------------------------------------

-- | Unique factorisation of Gaussian integers, delegating to this module's
-- 'factorise' and 'isPrime'. Zero has no factorisation and yields @[]@.
instance U.UniqueFactorisation GaussianInteger where
  factorise g
    | g == 0    = []
    | otherwise = coerce (factorise g)

  isPrime g
    | isPrime g = Just (Prime g)
    | otherwise = Nothing
```