-- |
-- Module:      Math.NumberTheory.Roots.Squares.Internal
-- Copyright:   (c) 2011 Daniel Fischer, 2016-2020 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Internal functions dealing with square roots. End-users should not import this module.

{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE MagicHash        #-}

module Math.NumberTheory.Roots.Squares.Internal
  ( karatsubaSqrt
  , isqrtA
  ) where

import Data.Bits (finiteBitSize, unsafeShiftL, unsafeShiftR, (.&.), (.|.))

import GHC.Exts (Int(..), Int#, uncheckedIShiftRA#, isTrue#, int2Double#, sqrtDouble#, double2Int#, (<#), (*#), (-#))
import GHC.Integer.GMP.Internals (Integer(..), shiftLInteger, shiftRInteger, sizeofBigNat#)
import GHC.Integer.Logarithms (integerLog2#)

-- Integer square root (rounded down) for any 'Integral' type.
--
-- Find an approximation to the square root in 'Integer' ('appSqrt'), then
-- find the exact integer square root with the integer variant of Heron's
-- method ('heron'). Takes only a handful of steps unless the input is
-- really large.
--
-- Zero is answered directly because 'heron' would divide by zero on it.
-- Negative arguments are rejected with the same error message that
-- 'appSqrt' uses, so small and large negative values now fail the same way.
{-# SPECIALISE isqrtA :: Integer -> Integer #-}
isqrtA :: Integral a => a -> a
isqrtA 0 = 0
isqrtA n
    | n < 0     = error "integerSquareRoot': negative argument"
    | otherwise = heron n (fromInteger . appSqrt . fromIntegral $ n)

-- Heron's method for integers: @heron n a@ returns the integer square root
-- @r = floor (sqrt n)@, starting from the initial guess @a@.
--
-- First make one step to ensure the value we're working on is @>= r@. From
-- then on the iterates decrease strictly until they reach @r@, and for such
-- a value we have @k == r@ iff @k <= step k@.
--
-- Preconditions (not checked): @n >= 1@ and @a >= 1@. With @n == 0@ the
-- iterates reach 0 and the next step divides by zero.
{-# SPECIALISE heron :: Integer -> Integer -> Integer #-}
heron :: Integral a => a -> a -> a
heron n a = go (step a)
  where
    -- One Heron step: the integer mean of k and n / k.
    step k = (k + n `quot` k) `quot` 2

    -- Keep stepping while the iterate still decreases.
    go k
        | m < k     = go m
        | otherwise = k
      where
        m = step k

-- Find a fairly good approximation to the square root of a non-negative
-- 'Integer'. At most one off for small Integers; about 48 bits should be
-- correct for large Integers.
--
-- There's already a check for negative in integerSquareRoot, but
-- integerSquareRoot' is exported directly too, so every negative argument
-- (small or large) ends in the error clause at the bottom.
appSqrt :: Integer -> Integer
appSqrt (S# i#)
    -- Only non-negative machine-sized values may use the Double square root.
    -- For a negative i#, sqrtDouble# yields NaN and double2Int# of NaN is not
    -- a meaningful Int. When this guard fails, matching falls through to the
    -- final error equation.
    | I# i# >= 0 = S# (double2Int# (sqrtDouble# (int2Double# i#)))
appSqrt n@(Jp# bn#)
    -- At most 256 bits: convert to Double directly.
    | isTrue# (sizeofBigNat# bn# <# thresh#) =
        floor (sqrt (fromInteger n) :: Double)
    -- Larger: with l = integerLog2 n and h = l / 2 - 47, drop the low 2h
    -- bits (leaving 95 or 96 significant bits), take the Double square root
    -- of what remains and shift the result back up by h bits.
    | otherwise =
        case integerLog2# n of
          l# -> case uncheckedIShiftRA# l# 1# -# 47# of
            h# -> case shiftRInteger n (2# *# h#) of
              m -> case floor (sqrt (fromInteger m) :: Double) of
                r -> shiftLInteger r h#
  where
    -- Threshold (in limbs) for shifting vs. direct fromInteger:
    -- we shift when we expect more than 256 bits.
    thresh# :: Int#
    thresh# = if finiteBitSize (0 :: Word) == 64 then 5# else 9#
appSqrt _ = error "integerSquareRoot': negative argument"


-- Integer square root with remainder, using the Karatsuba Square Root
-- algorithm from
-- Paul Zimmermann. Karatsuba Square Root. [Research Report] RR-3805, 1999,
-- pp.8. <inria-00072854>
--
-- For non-negative n, @karatsubaSqrt n@ returns @(s, r)@ with
-- @s = floor (sqrt n)@ and @r = n - s * s@. Inputs below 2^2300 are handled
-- directly by 'isqrtA'. Larger ones are split into four k-bit limbs and
-- handed to 'karatsubaStep', which recurses on the upper half.
karatsubaSqrt :: Integer -> (Integer, Integer)
karatsubaSqrt 0 = (0, 0)
karatsubaSqrt n
    | lgN < 2300 =
        let s = isqrtA n in (s, n - s * s)
    -- Before splitting n into 4 parts we must ensure that the top part is
    -- at least 2^k / 4. The leading bit of n (bit lgN) must therefore lie in
    -- one of the two most significant positions of the top limb, i.e.
    -- lgN mod 4 is 2 or 3. That is exactly when bit 1 of lgN is set.
    | lgN .&. 2 /= 0 =
        karatsubaStep k (karatsubaSplit k n)
    -- Otherwise lgN mod 4 is 0 or 1. Scaling n by 4 moves its leading bit
    -- up two positions into the allowed range, and 4n still fits in four
    -- k-bit limbs. Take the root of 4n and scale back:
    --   sqrt (4n) = 2t     => isqrt n = t, remainder r / 4
    --   sqrt (4n) = 2t + 1 => isqrt n = t, remainder (r + 2s - 1) / 4
    | otherwise =
        let (s, r) = karatsubaStep k (karatsubaSplit k (n `unsafeShiftL` 2))
            r' | s .&. 1 == 0 = r
               | otherwise    = r + double s - 1
        in  (s `unsafeShiftR` 1, r' `unsafeShiftR` 2)
  where
    -- Limb width: four limbs of k bits cover all lgN + 1 bits of n.
    k = lgN `unsafeShiftR` 2 + 1
    lgN = I# (integerLog2# n)

-- One step of Zimmermann's Karatsuba square root. It takes the limb width k
-- and the four k-bit limbs (a3, a2, a1, a0) of n, most significant first,
-- as produced by 'karatsubaSplit'. It returns (s, r) with s = floor (sqrt n)
-- and r = n - s * s.
--
-- Expects the normalisation arranged by 'karatsubaSqrt': a3 >= 2^k / 4.
karatsubaStep :: Int -> (Integer, Integer, Integer, Integer) -> (Integer, Integer)
karatsubaStep k (a3, a2, a1, a0)
    | r >= 0    = (s, r)
    -- The tentative root is one too large: use s - 1, whose remainder is
    -- n - (s - 1)^2 = r + 2s - 1.
    | otherwise = (s - 1, r + double s - 1)
  where
    -- Square root with remainder of the upper two limbs, recursively.
    (s', r') = karatsubaSqrt (cat a3 a2)
    -- Next k bits of the root: divide (r' : a1) by 2 s'.
    (q, u) = cat r' a1 `quotRem` double s'
    -- Tentative root and (possibly negative) remainder.
    s = s' `unsafeShiftL` k + q
    r = cat u a0 - q * q

    -- Place x above a k-bit low part y.
    cat :: Integer -> Integer -> Integer
    cat x y = x `unsafeShiftL` k .|. y
    {-# INLINE cat #-}

-- Split n0 into four limbs (a3, a2, a1, a0), most significant first, with
-- n0 = a3 * 2^(3k) + a2 * 2^(2k) + a1 * 2^k + a0.
-- a0, a1 and a2 are masked to the low k bits; a3 keeps every remaining
-- high bit (it is not masked).
karatsubaSplit :: Int -> Integer -> (Integer, Integer, Integer, Integer)
karatsubaSplit k n0 = (a3, a2, a1, a0)
  where
    -- Mask selecting the low k bits.
    m = 1 `unsafeShiftL` k - 1
    -- Separate x into (everything above the low k bits, the low k bits).
    peel x = (x `unsafeShiftR` k, x .&. m)
    (n1, a0) = peel n0
    (n2, a1) = peel n1
    (a3, a2) = peel n2

-- Multiply an 'Integer' by two using a single left shift.
double :: Integer -> Integer
double x = x `unsafeShiftL` 1
{-# INLINE double #-}