{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
module Math.NumberTheory.Roots.Squares.Internal
( karatsubaSqrt
, isqrtA
) where
import Data.Bits (finiteBitSize, unsafeShiftL, unsafeShiftR, (.&.), (.|.))
import GHC.Exts (Int(..), Int#, uncheckedIShiftRA#, isTrue#, int2Double#, sqrtDouble#, double2Int#, (<#), (*#), (-#))
import GHC.Integer.GMP.Internals (Integer(..), shiftLInteger, shiftRInteger, sizeofBigNat#)
import GHC.Integer.Logarithms (integerLog2#)
{-# SPECIALISE isqrtA :: Integer -> Integer #-}
-- | Integer square root for any 'Integral' type.
--
-- Zero is answered directly. Every other input is converted to 'Integer',
-- passed to 'appSqrt' to obtain a starting estimate, converted back to the
-- caller's type, and then refined with 'heron'.
--
-- NOTE(review): no sign check is performed here; callers are presumably
-- expected to pass non-negative values only.
isqrtA :: Integral a => a -> a
isqrtA 0 = 0
isqrtA n = heron n (fromInteger . appSqrt . fromIntegral $ n)
{-# SPECIALISE heron :: Integer -> Integer -> Integer #-}
-- | Refine an estimate @a@ of the square root of @n@ with the integer
-- variant of Heron's iteration.
--
-- One step is always taken from @a@; after that the iteration continues
-- for as long as each new estimate is strictly smaller than the previous
-- one, and the last estimate that was not improved upon is returned.
--
-- The starting estimate must be non-zero, because every step divides @n@
-- by the current estimate.
heron :: Integral a => a -> a -> a
heron n a = go (step a)
  where
    -- A single iteration: the truncated mean of k and n `quot` k.
    step k = (k + n `quot` k) `quot` 2

    -- Descend while the sequence is strictly decreasing.
    go k
      | m < k = go m
      | otherwise = k
      where
        m = step k
-- | Approximate square root of an 'Integer', computed through 'Double'.
--
-- * Small integers (@S#@) go straight through the primitive
--   @sqrtDouble#@ and are converted back with @double2Int#@. This
--   equation applies no sign check.
-- * Big positive integers (@Jp#@) whose @sizeofBigNat#@ is below the
--   threshold are converted to 'Double' as a whole.
-- * Larger big positive integers are first shifted right by an even
--   number of bits, the square root of the shifted value is taken in
--   'Double', and the result is shifted left by half that amount.
-- * Anything else (the remaining constructor, i.e. big negative values)
--   raises an error.
--
-- The result is only an estimate; 'isqrtA' refines it with 'heron'.
appSqrt :: Integer -> Integer
appSqrt (S# i#) = S# (double2Int# (sqrtDouble# (int2Double# i#)))
appSqrt n@(Jp# bn#)
  | isTrue# ((sizeofBigNat# bn#) <# thresh#) =
      floor (sqrt $ fromInteger n :: Double)
  | otherwise =
      case integerLog2# n of
        l# ->
          -- h# is half the shift amount: with l# = integerLog2# n, the
          -- shifted value m below has integerLog2 equal to 94 or 95.
          case uncheckedIShiftRA# l# 1# -# 47# of
            h# ->
              case shiftRInteger n (2# *# h#) of
                m ->
                  case floor (sqrt $ fromInteger m :: Double) of
                    r -> shiftLInteger r h#
  where
    -- Size bound (as reported by sizeofBigNat#) below which the direct
    -- Double conversion is used: 5 when Word is 64 bits wide, else 9.
    thresh# :: Int#
    thresh# = if finiteBitSize (0 :: Word) == 64 then 5# else 9#
appSqrt _ = error "integerSquareRoot': negative argument"
-- | Square root with remainder: for the input @n@ the result @(s, r)@ is
-- built so that @n == s * s + r@.
--
-- Inputs whose @integerLog2#@ is below 2300 are handled by 'isqrtA'
-- directly. Larger inputs are cut into four @k@-bit pieces by
-- 'karatsubaSplit' and combined by 'karatsubaStep', which recurses back
-- into this function on the upper half.
--
-- When bit 1 of @lgN@ is clear, the top piece produced by the split would
-- be too short, so @n@ is multiplied by 4 first; the root is then halved
-- and the remainder adjusted and divided by 4 afterwards.
karatsubaSqrt :: Integer -> (Integer, Integer)
karatsubaSqrt 0 = (0, 0)
karatsubaSqrt n
  | lgN < 2300 =
      let s = isqrtA n
       in (s, n - s * s)
  | otherwise =
      if (lgN .&. 2) /= 0
        then karatsubaStep k (karatsubaSplit k n)
        else
          let n' = n `unsafeShiftL` 2
              (s, r) = karatsubaStep k (karatsubaSplit k n')
              -- If the root of 4n is odd, step down to the even root
              -- s - 1 and move the difference 2s - 1 into the remainder.
              r'
                | (s .&. 1) == 0 = r
                | otherwise = r + double s - 1
           in (s `unsafeShiftR` 1, r' `unsafeShiftR` 2)
  where
    -- Width in bits of each of the four pieces.
    k :: Int
    k = (lgN `unsafeShiftR` 2) + 1

    lgN :: Int
    lgN = I# (integerLog2# n)
-- | Combine step of the recursive square root.
--
-- Takes the piece width @k@ and the four pieces @(a3, a2, a1, a0)@, most
-- significant first, and returns a root/remainder pair.
--
-- The root and remainder of the upper half @cat a3 a2@ are obtained from
-- 'karatsubaSqrt'. That remainder, extended by @a1@, is divided by twice
-- the upper root to get the low part @q@ of the candidate root and a
-- division remainder @u@. The candidate remainder is @cat u a0 - q * q@;
-- if it is negative, the root is decreased by one and the remainder is
-- corrected by @2 * s - 1@.
karatsubaStep :: Int -> (Integer, Integer, Integer, Integer) -> (Integer, Integer)
karatsubaStep k (a3, a2, a1, a0)
  | r >= 0 = (s, r)
  | otherwise = (s - 1, r + double s - 1)
  where
    r = cat u a0 - q * q
    s = (s' `unsafeShiftL` k) + q
    (q, u) = cat r' a1 `quotRem` double s'
    (s', r') = karatsubaSqrt (cat a3 a2)

    -- Place x above the k-bit value y.
    cat x y = (x `unsafeShiftL` k) .|. y
    {-# INLINE cat #-}
-- | Cut @n0@ into four pieces of @k@ bits each, returned most significant
-- first as @(a3, a2, a1, a0)@.
--
-- @a0@, @a1@ and @a2@ are masked to exactly @k@ bits; @a3@ is whatever
-- remains after shifting @n0@ right by @3 * k@ bits and is not masked.
karatsubaSplit :: Int -> Integer -> (Integer, Integer, Integer, Integer)
karatsubaSplit k n0 = (a3, a2, a1, a0)
  where
    a3 = n3
    n3 = n2 `unsafeShiftR` k
    a2 = n2 .&. m
    n2 = n1 `unsafeShiftR` k
    a1 = n1 .&. m
    n1 = n0 `unsafeShiftR` k
    a0 = n0 .&. m

    -- Mask selecting the low k bits.
    m = (1 `unsafeShiftL` k) - 1
-- | Multiply an 'Integer' by two using a left shift by one bit.
double :: Integer -> Integer
double x = x `unsafeShiftL` 1
{-# INLINE double #-}