{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
{-# OPTIONS_GHC -fno-warn-deprecations #-}
module Math.NumberTheory.GCD
( binaryGCD
, extendedGCD
, coprime
) where
import Data.Bits
import Data.Semigroup
import GHC.Word
import GHC.Int
import Math.NumberTheory.GCD.LowLevel
import Math.NumberTheory.Utils
#include "MachDeps.h"
-- Type-specific rewrites for 'binaryGCD'. When these RULES fire, calls at the
-- listed concrete types go to the imported gcdInt / gcdWord or to the
-- gi*/gw* wrappers at the end of this module (built on the primitive
-- gcdInt# / gcdWord#), instead of the generic 'binaryGCDImpl' loop.
{-# RULES
"binaryGCD/Int" binaryGCD = gcdInt
"binaryGCD/Word" binaryGCD = gcdWord
"binaryGCD/Int8" binaryGCD = gi8
"binaryGCD/Int16" binaryGCD = gi16
"binaryGCD/Int32" binaryGCD = gi32
"binaryGCD/Word8" binaryGCD = gw8
"binaryGCD/Word16" binaryGCD = gw16
"binaryGCD/Word32" binaryGCD = gw32
#-}
#if WORD_SIZE_IN_BITS == 64
-- | 'binaryGCD' at 'Int64'. Compiled only when WORD_SIZE_IN_BITS is 64; the
-- unboxed payloads are handed straight to the primitive gcdInt#.
gi64 :: Int64 -> Int64 -> Int64
gi64 = \(I64# m#) (I64# n#) -> I64# (gcdInt# m# n#)

-- | 'binaryGCD' at 'Word64', via the primitive gcdWord#.
gw64 :: Word64 -> Word64 -> Word64
gw64 = \(W64# m#) (W64# n#) -> W64# (gcdWord# m# n#)

-- Route the 64-bit instantiations of 'binaryGCD' to the wrappers above.
{-# RULES
"binaryGCD/Int64" binaryGCD = gi64
"binaryGCD/Word64" binaryGCD = gw64
  #-}
#endif
-- Keep 'binaryGCD' from being inlined before simplifier phase 1, so the
-- type-specific "binaryGCD/..." RULES in this module can still match the
-- un-inlined call.
{-# INLINE [1] binaryGCD #-}
-- | Greatest common divisor of two integral numbers, computed with the
-- binary gcd algorithm (strip the common power of two, then run a
-- subtractive loop on the odd parts).
--
-- When optimising, RULES may replace this at 'Int', 'Word', the 8/16/32-bit
-- fixed-width types and (on 64-bit targets) 'Int64' / 'Word64' with
-- primitive implementations; every other type uses 'binaryGCDImpl'.
--
-- Deprecated in favour of 'Math.NumberTheory.Euclidean.gcd'.
binaryGCD :: (Integral a, Bits a) => a -> a -> a
binaryGCD = binaryGCDImpl
{-# DEPRECATED binaryGCD "Use 'Math.NumberTheory.Euclidean.gcd'" #-}
#if WORD_SIZE_IN_BITS < 64
{-# SPECIALISE binaryGCDImpl :: Word64 -> Word64 -> Word64,
                                Int64 -> Int64 -> Int64 #-}
#endif
{-# SPECIALISE binaryGCDImpl :: Integer -> Integer -> Integer #-}
-- | Generic worker behind 'binaryGCD'.
--
-- A zero argument short-circuits to the absolute value of the other one.
-- Otherwise both absolute values are split into a power of two and an odd
-- part; the shared power of two is reapplied to the gcd of the odd parts.
binaryGCDImpl :: (Integral a, Bits a) => a -> a -> a
binaryGCDImpl x y
  | y == 0 = abs x
  | x == 0 = abs y
  | otherwise =
      case (shiftToOddCount (abs x), shiftToOddCount (abs y)) of
        ((!tzX, !oddX), (!tzY, !oddY)) ->
          -- The odd parts get a second 'abs' because, for 'minBound' of a
          -- bounded signed type, 'abs' overflows and the odd part can come
          -- back negative.
          gcdOdd (abs oddX) (abs oddY) `shiftL` min tzX tzY
{-# SPECIALISE extendedGCD :: Int -> Int -> (Int, Int, Int),
                              Word -> Word -> (Word, Word, Word),
                              Integer -> Integer -> (Integer, Integer, Integer)
  #-}
-- | Extended Euclidean algorithm.
--
-- @extendedGCD a b@ returns @(d, u, v)@, where @d@ is the gcd of
-- @abs a@ and @abs b@ and @u * a + v * b == d@ (modulo overflow for
-- fixed-width types). The remainder sequence runs on the absolute values;
-- coefficient signs are flipped afterwards to match negative inputs.
extendedGCD :: Integral a => a -> a -> (a, a, a)
extendedGCD a b = (g, adjust a cu, adjust b cv)
  where
    (g, cu, cv) = go (abs a) (abs b) 1 0 0 1

    -- Undo the 'abs' applied to an input: a negative input needs its
    -- coefficient negated for the identity to hold.
    adjust n c = if n < 0 then negate c else c

    -- go r s x0 x1 y0 y1: r and s are consecutive remainders, with
    -- r == x0 * abs a + y0 * abs b and s == x1 * abs a + y1 * abs b.
    -- The newest coefficients (x1, y1) are forced on entry to every step.
    go r s x0 !x1 y0 !y1
      | s == 0 = (r, x0, y0)
      | otherwise = case quotRem r s of
          (q, t) -> go s t x1 (x0 - q * x1) y1 (y0 - q * y1)
{-# DEPRECATED extendedGCD "Use 'Math.NumberTheory.Euclidean.extendedGCD'" #-}
-- Type-specific rewrites for 'coprime'. When these RULES fire, calls at the
-- listed concrete types go to the imported coprimeInt / coprimeWord or to
-- the ci*/cw* wrappers at the end of this module (built on the primitive
-- coprimeInt# / coprimeWord#), instead of the generic 'coprimeImpl'.
{-# RULES
"coprime/Int" coprime = coprimeInt
"coprime/Word" coprime = coprimeWord
"coprime/Int8" coprime = ci8
"coprime/Int16" coprime = ci16
"coprime/Int32" coprime = ci32
"coprime/Word8" coprime = cw8
"coprime/Word16" coprime = cw16
"coprime/Word32" coprime = cw32
#-}
#if WORD_SIZE_IN_BITS == 64
-- | 'coprime' at 'Int64'. Compiled only when WORD_SIZE_IN_BITS is 64; the
-- unboxed payloads are handed straight to the primitive coprimeInt#.
ci64 :: Int64 -> Int64 -> Bool
ci64 = \(I64# m#) (I64# n#) -> coprimeInt# m# n#

-- | 'coprime' at 'Word64', via the primitive coprimeWord#.
cw64 :: Word64 -> Word64 -> Bool
cw64 = \(W64# m#) (W64# n#) -> coprimeWord# m# n#

-- Route the 64-bit instantiations of 'coprime' to the wrappers above.
{-# RULES
"coprime/Int64" coprime = ci64
"coprime/Word64" coprime = cw64
  #-}
#endif
-- Keep 'coprime' from being inlined before simplifier phase 1, so the
-- type-specific "coprime/..." RULES in this module can still match the
-- un-inlined call.
{-# INLINE [1] coprime #-}
-- | Test whether two integral numbers are coprime, i.e. whether their
-- greatest common divisor is 1.
--
-- When optimising, RULES may replace this at 'Int', 'Word' and the
-- fixed-width types with primitive implementations; every other type uses
-- 'coprimeImpl'.
--
-- Deprecated in favour of 'Math.NumberTheory.Euclidean.coprime'.
coprime :: (Integral a, Bits a) => a -> a -> Bool
coprime = coprimeImpl
{-# DEPRECATED coprime "Use 'Math.NumberTheory.Euclidean.coprime'" #-}
#if WORD_SIZE_IN_BITS < 64
{-# SPECIALISE coprimeImpl :: Word64 -> Word64 -> Bool,
                              Int64 -> Int64 -> Bool #-}
#endif
{-# SPECIALISE coprimeImpl :: Integer -> Integer -> Bool #-}
-- | Generic worker behind 'coprime'.
--
-- Cheap cases are settled first: an argument of absolute value 1 is coprime
-- to anything. After that, a zero argument or two even arguments rule
-- coprimality out. Only the remaining case compares the odd parts with
-- 'gcdOdd'.
coprimeImpl :: (Integral a, Bits a) => a -> a -> Bool
coprimeImpl a b
  | x == 1 || y == 1 = True
  | x /= 0, y /= 0, ((a .|. b) .&. 1) == 1 =
      -- At least one input is odd, so 2 cannot divide the gcd and every
      -- factor of two may be stripped from both sides.
      gcdOdd (abs (shiftToOdd x)) (abs (shiftToOdd y)) == 1
  | otherwise = False
  where
    x = abs a
    y = abs b
{-# INLINE gcdOdd #-}
-- | gcd of two odd, non-negative numbers. Callers in this module pass odd
-- parts obtained from shiftToOdd / shiftToOddCount, made non-negative with
-- 'abs'.
--
-- A unit argument answers immediately; otherwise the larger value is put
-- first before entering the 'oddGCD' loop.
gcdOdd :: (Integral a, Bits a) => a -> a -> a
gcdOdd a b
  | a == 1 || b == 1 = 1
  | otherwise = case compare a b of
      LT -> oddGCD b a
      GT -> oddGCD a b
      EQ -> a
{-# SPECIALISE oddGCD :: Integer -> Integer -> Integer #-}
#if WORD_SIZE_IN_BITS < 64
{-# SPECIALISE oddGCD :: Int64 -> Int64 -> Int64,
                         Word64 -> Word64 -> Word64
  #-}
#endif
-- | Subtractive binary gcd loop.
--
-- Precondition, maintained by 'gcdOdd' and by the recursion: both arguments
-- are odd and the first is strictly larger. Each step replaces the pair by
-- the odd part of their difference together with the smaller argument,
-- larger one first.
oddGCD :: (Integral a, Bits a) => a -> a -> a
oddGCD a b
  | d == 1 = 1
  | d < b = oddGCD b d
  | d > b = oddGCD d b
  | otherwise = d
  where
    -- a - b is even and positive here, so stripping its factors of two
    -- leaves an odd value.
    d = shiftToOdd (a - b)
-- | 'binaryGCD' at 'Int8', computed by the primitive gcdInt# on the
-- machine-word payloads.
--
-- The non-negative gcd of two 'Int8' values can be 128 (for instance for
-- @minBound@ and @0@), which lies outside the 'Int8' range. Storing that
-- raw result with 'I8#' would build an invalid 'Int8'. Going through
-- 'fromIntegral' narrows it to 8 bits instead. That case then yields
-- 'minBound', the same value Prelude 'gcd' gives because 'abs' overflows
-- there too, and every in-range result is unchanged.
gi8 :: Int8 -> Int8 -> Int8
gi8 (I8# x#) (I8# y#) = fromIntegral (I# (gcdInt# x# y#))

-- | 'binaryGCD' at 'Int16'. Same narrowing as 'gi8': the gcd can be 32768
-- (for @minBound@ and @0@), which must wrap back into the 'Int16' range.
gi16 :: Int16 -> Int16 -> Int16
gi16 (I16# x#) (I16# y#) = fromIntegral (I# (gcdInt# x# y#))

-- | 'binaryGCD' at 'Int32'. Same narrowing as 'gi8': the gcd can be 2^31
-- (for @minBound@ and @0@), which must wrap back into the 'Int32' range.
gi32 :: Int32 -> Int32 -> Int32
gi32 (I32# x#) (I32# y#) = fromIntegral (I# (gcdInt# x# y#))
-- | 'binaryGCD' at 'Word8', via the primitive gcdWord# on the payloads.
-- A gcd of unsigned values never exceeds the larger operand, so the
-- result needs no narrowing.
gw8 :: Word8 -> Word8 -> Word8
gw8 = \(W8# m#) (W8# n#) -> W8# (gcdWord# m# n#)

-- | 'binaryGCD' at 'Word16', via the primitive gcdWord#.
gw16 :: Word16 -> Word16 -> Word16
gw16 = \(W16# m#) (W16# n#) -> W16# (gcdWord# m# n#)

-- | 'binaryGCD' at 'Word32', via the primitive gcdWord#.
gw32 :: Word32 -> Word32 -> Word32
gw32 = \(W32# m#) (W32# n#) -> W32# (gcdWord# m# n#)
-- | 'coprime' at 'Int8', via the primitive coprimeInt# on the payloads.
ci8 :: Int8 -> Int8 -> Bool
ci8 = \(I8# p#) (I8# q#) -> coprimeInt# p# q#

-- | 'coprime' at 'Int16', via the primitive coprimeInt#.
ci16 :: Int16 -> Int16 -> Bool
ci16 = \(I16# p#) (I16# q#) -> coprimeInt# p# q#

-- | 'coprime' at 'Int32', via the primitive coprimeInt#.
ci32 :: Int32 -> Int32 -> Bool
ci32 = \(I32# p#) (I32# q#) -> coprimeInt# p# q#
-- | 'coprime' at 'Word8', via the primitive coprimeWord# on the payloads.
cw8 :: Word8 -> Word8 -> Bool
cw8 = \(W8# p#) (W8# q#) -> coprimeWord# p# q#

-- | 'coprime' at 'Word16', via the primitive coprimeWord#.
cw16 :: Word16 -> Word16 -> Bool
cw16 = \(W16# p#) (W16# q#) -> coprimeWord# p# q#

-- | 'coprime' at 'Word32', via the primitive coprimeWord#.
cw32 :: Word32 -> Word32 -> Bool
cw32 = \(W32# p#) (W32# q#) -> coprimeWord# p# q#