{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
module Math.NumberTheory.Roots.Squares.Internal
( karatsubaSqrt
, isqrtA
) where
import Data.Bits (finiteBitSize, unsafeShiftL, unsafeShiftR, (.&.), (.|.))
import GHC.Exts (Int(..), Int#, uncheckedIShiftRA#, isTrue#, int2Double#, sqrtDouble#, double2Int#, (<#), (*#), (-#))
import GHC.Integer.GMP.Internals (Integer(..), shiftLInteger, shiftRInteger, sizeofBigNat#)
import GHC.Integer.Logarithms (integerLog2#)
-- | Integer square root for any 'Integral' type.
--
-- Zero maps to zero. Any other argument is converted to 'Integer', a
-- starting estimate is obtained from 'appSqrt', and that estimate
-- (converted back to the caller's type) is refined by 'heron'.
--
-- NOTE(review): the argument is expected to be non-negative; whether a
-- negative value is rejected depends on 'appSqrt'.
{-# SPECIALISE isqrtA :: Integer -> Integer #-}
isqrtA :: Integral a => a -> a
isqrtA 0 = 0
isqrtA n = heron n firstGuess
  where
    -- Initial approximation, computed on 'Integer' and narrowed back to @a@.
    firstGuess = fromInteger (appSqrt (fromIntegral n))
-- | Heron's iteration on integers.
--
-- @heron n a@ starts from the estimate @a@, applies one averaging step
-- @(k + n `quot` k) `quot` 2@ unconditionally, and then keeps stepping for
-- as long as the new value is strictly smaller than the current one. The
-- first value that fails to decrease is returned.
--
-- The estimate must be non-zero, since each step divides by the current
-- value.
{-# SPECIALISE heron :: Integer -> Integer -> Integer #-}
heron :: Integral a => a -> a -> a
heron n a = descend (improve a)
  where
    -- One averaging step of the iteration.
    improve x = (x + n `quot` x) `quot` 2

    -- Step downwards until the sequence stops decreasing.
    descend x =
      let next = improve x
      in if next < x then descend next else x
-- | Approximate square root of a non-negative 'Integer', computed through
-- 'Double' arithmetic.
--
-- * Small integers (@S#@) are converted to 'Double', square-rooted and
--   truncated back to an 'Int'.
-- * Positive big integers (@Jp#@) with fewer limbs than the threshold are
--   converted to 'Double' directly.
-- * Larger positive big integers are first shifted right by an even number
--   of bits (@2 * h@) so that only the leading bits remain, the square root
--   of that prefix is taken in 'Double', and the result is shifted left by
--   @h@ bits.
--
-- Negative arguments of either representation raise an 'error'. Without the
-- explicit guard on the @S#@ case, a negative small integer would be pushed
-- through the floating-point square root and yield an unspecified value
-- instead of failing like a negative big integer does.
appSqrt :: Integer -> Integer
appSqrt (S# i#)
  -- Reject negative small integers up front, mirroring the big-integer case.
  | isTrue# (i# <# 0#) = error "integerSquareRoot': negative argument"
  | otherwise = S# (double2Int# (sqrtDouble# (int2Double# i#)))
appSqrt n@(Jp# bn#)
  | isTrue# (sizeofBigNat# bn# <# thresh#) =
      floor (sqrt $ fromInteger n :: Double)
  | otherwise = case integerLog2# n of
      -- h# = floor (log2 n / 2) - 47: half the bit position, minus headroom
      -- so the shifted value keeps its leading bits for the Double sqrt.
      l# -> case uncheckedIShiftRA# l# 1# -# 47# of
        h# -> case shiftRInteger n (2# *# h#) of
          m -> case floor (sqrt $ fromInteger m :: Double) of
            r -> shiftLInteger r h#
  where
    -- Limb-count threshold between direct conversion and shifting:
    -- 5 limbs when 'Word' is 64 bits wide, 9 limbs otherwise.
    thresh# :: Int#
    thresh# = if finiteBitSize (0 :: Word) == 64 then 5# else 9#
-- Remaining constructor: negative big integers.
appSqrt _ = error "integerSquareRoot': negative argument"
-- | Square root with remainder: returns a pair @(s, r)@.
--
-- * Zero yields @(0, 0)@.
-- * When @integerLog2 n@ is below 2300, @s@ is 'isqrtA' of @n@ and
--   @r = n - s * s@.
-- * Otherwise @n@ is cut into four chunks of @k = lgN \/ 4 + 1@ bits by
--   'karatsubaSplit' and handed to 'karatsubaStep'.
--
-- In the large case, if bit 1 of @lgN@ is clear the input is multiplied by
-- four before splitting, and the result is scaled back afterwards: the root
-- is halved and the remainder divided by four. If the root of @4 * n@ is odd,
-- the remainder is first adjusted to the one belonging to the next lower
-- (even) root.
--
-- NOTE(review): the pre-scaling presumably ensures the top chunk is large
-- enough for the recursive step; confirm against the algorithm's reference.
karatsubaSqrt :: Integer -> (Integer, Integer)
karatsubaSqrt 0 = (0, 0)
karatsubaSqrt n
  | lg < 2300 = (smallRoot, n - smallRoot * smallRoot)
  | lg .&. 2 /= 0 = karatsubaStep chunk (karatsubaSplit chunk n)
  | otherwise = (scaledRoot `unsafeShiftR` 1, evenRem `unsafeShiftR` 2)
  where
    -- Position of the highest set bit of n.
    lg = I# (integerLog2# n)
    -- Width in bits of each of the four chunks.
    chunk = lg `unsafeShiftR` 2 + 1
    -- Direct computation used below the size threshold.
    smallRoot = isqrtA n
    -- Root and remainder of 4 * n.
    (scaledRoot, scaledRem) =
      karatsubaStep chunk (karatsubaSplit chunk (n `unsafeShiftL` 2))
    -- Remainder relative to an even root, so both halves scale back exactly.
    evenRem
      | scaledRoot .&. 1 == 0 = scaledRem
      | otherwise = scaledRem + double scaledRoot - 1
-- | One recursive step of the square root with remainder.
--
-- Takes the chunk width @k@ and four chunks @(a3, a2, a1, a0)@, most
-- significant first. The steps are:
--
-- 1. Compute the root and remainder of the upper half (@a3@ joined with
--    @a2@) through 'karatsubaSqrt'.
-- 2. Divide the upper remainder joined with @a1@ by twice the upper root,
--    giving quotient @q@ and remainder @u@.
-- 3. Form the candidate root as the upper root shifted left by @k@ plus @q@,
--    and the candidate remainder as @u@ joined with @a0@ minus @q * q@.
-- 4. If the candidate remainder is negative, decrement the root and correct
--    the remainder accordingly.
--
-- Chunks are joined with a shift and bitwise OR, which assumes the low
-- chunk fits in @k@ bits.
karatsubaStep :: Int -> (Integer, Integer, Integer, Integer) -> (Integer, Integer)
karatsubaStep k (a3, a2, a1, a0)
  | candRem >= 0 = (candRoot, candRem)
  | otherwise = (candRoot - 1, candRem + double candRoot - 1)
  where
    -- Join a high part and a k-bit low part into one number.
    glue hi lo = (hi `unsafeShiftL` k) .|. lo
    {-# INLINE glue #-}

    -- Root and remainder of the upper half.
    (hiRoot, hiRem) = karatsubaSqrt (glue a3 a2)

    -- Next k bits of the root, obtained by division.
    (q, u) = glue hiRem a1 `quotRem` double hiRoot

    candRoot = (hiRoot `unsafeShiftL` k) + q
    candRem = glue u a0 - q * q
-- | Cut an integer into four chunks of @k@ bits each.
--
-- The result is @(a3, a2, a1, a0)@, most significant chunk first. The three
-- lower chunks are masked to exactly @k@ bits. The top chunk @a3@ is
-- whatever remains after shifting right by @3 * k@ bits and is not masked.
karatsubaSplit :: Int -> Integer -> (Integer, Integer, Integer, Integer)
karatsubaSplit k n0 = (top, lowBits above2, lowBits above1, lowBits n0)
  where
    -- Mask selecting the lowest k bits.
    mask :: Integer
    mask = (1 `unsafeShiftL` k) - 1

    lowBits x = x .&. mask

    -- Successive right shifts by one chunk width.
    above1 = n0 `unsafeShiftR` k
    above2 = above1 `unsafeShiftR` k
    top = above2 `unsafeShiftR` k
-- | Multiply an 'Integer' by two using a one-bit left shift.
double :: Integer -> Integer
double x = unsafeShiftL x 1
{-# INLINE double #-}