-- Copyright 2021 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     https://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

-- | Rounded base-2 logarithms of 'Integral' and 'Real' types.

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}

module Numeric.Logarithms
         ( -- * Real Logarithms
           log2Floor, log2Ceiling
           -- ** Advanced
         , log2Approx, log2With
           -- * Integer Logarithms
         , ilog2Floor, ilog2Ceiling
           -- ** Advanced
         , ilog2Approx, ilog2With
         ) where

import Data.Bits (bit)
import Data.Ratio (denominator, numerator)
import GHC.Exts (Int(I#))
import GHC.Integer (shiftLInteger)
import GHC.Integer.Logarithms (integerLog2#)
import GHC.Stack (HasCallStack)

-- | Floor of the base-2 logarithm of an 'Integer', i.e. the index of its most
-- significant set bit.
--
-- The argument must be strictly positive; this is not checked here (the
-- callers in this module reject non-positive inputs before calling it).
-- NOTE(review): the result for non-positive inputs is whatever
-- 'integerLog2#' produces — do not rely on it.
integerLog2Floor :: Integer -> Int
integerLog2Floor n = case integerLog2# n of
  lg# -> I# lg#

-- Derivation of log2 algorithms for Rational:
--
--   floor(log2(r))
-- = floor(log2(x/y)) with x, y the numerator and denominator of r
--     (r > 0, so x >= 1 and y >= 1).
-- = floor(log2((2^n * (1 + k/2^n)) / (2^m * (1 + j/2^m))))
--     with x = 2^n + k, 0 <= k < 2^n, n = floor(log2(x)) an integer
--          y = 2^m + j, 0 <= j < 2^m, m = floor(log2(y)) an integer
-- = floor(log2(2^n/2^m) + log2((1+k/2^n)/(1+j/2^m)))
-- = floor(n - m + log2((1+k/2^n)/(1+j/2^m)))
-- = n - m + floor(log2((1+k/2^n)/(1+j/2^m)))
-- = n - m + if (1+k/2^n) < (1+j/2^m) then -1 else 0
--     because 1+k/2^n and 1+j/2^m are both in [1, 2),
--     so their quotient is in (1/2, 2)
--     so log2 of their quotient is in (-1, 1)
--     so the floor of log2 of their quotient is -1 or 0:
--       -1 if the quotient is less than 1; 0 if the quotient is >= 1
-- = n - m + if k/2^n < j/2^m then -1 else 0
-- = n - m + if k*2^m < j*2^n then -1 else 0
-- = n - m + if (k<<m) < (j<<n) then -1 else 0
--
-- The same derivation holds for ceil(log2(r)) with adjustment values being 0
-- and 1 rather than -1 and 0, and the comparison being <= instead of <.

-- | Splits the positive fraction @num % den@ into a power of two and a
-- residual fraction close to 1.
--
-- Returns @(e, num', den')@ such that
-- @2^e * (num' 'Data.Ratio.%' den') == num % den@ and @num' % den'@ lies in
-- /(1\/2, 2)/ (exclusive at both bounds).  It does this by left-shifting
-- whichever of @num@ and @den@ has the lower leading-bit position until both
-- share the same leading-bit position; @e@ is the difference of the two
-- original positions.
--
-- Both arguments must be strictly positive; that is not checked here.
{-# INLINE splitLog2Unchecked #-}
splitLog2Unchecked :: Integer -> Integer -> (Int, Integer, Integer)
splitLog2Unchecked num den = (n - m, num', den')
 where
  -- Leading-bit positions of the numerator and the denominator.
  n, m :: Int
  n = integerLog2Floor num
  m = integerLog2Floor den

  -- Align the operand with the lower leading bit to the other one.
  (num', den')
    | m > n     = (num `shiftUp` (m - n), den)
    | otherwise = (num, den `shiftUp` (n - m))

  -- Left shift by a non-negative amount.  The Bits instance for Integer
  -- doesn't define unsafeShiftL or shiftL, so those default to testing the
  -- sign of the shift and dispatching to a left or right shift; calling
  -- shiftLInteger directly is what that definition should have been.
  shiftUp :: Integer -> Int -> Integer
  shiftUp x (I# s) = shiftLInteger x s

-- | Approximate base-2 logarithm of the fraction @num % den@.  Both arguments
-- must be strictly positive; that is not checked here.
--
-- The 'Int' is the exponent produced by 'splitLog2Unchecked'.  The 'Ordering'
-- compares the residual fraction against 1: 'LT' when the exact logarithm is
-- below the 'Int', 'EQ' when it equals it, and 'GT' when it is above.
{-# INLINE log2Unchecked #-}
log2Unchecked :: Integer -> Integer -> (Int, Ordering)
log2Unchecked num den = (lg, num' `compare` den')
 where
  (lg, num', den') = splitLog2Unchecked num den

-- The underscore-suffixed variants below are marked INLINE so that GHC
-- compiles each exported function into its own large chunk of optimized code.
-- The exported functions themselves carry no inlining pragma, so that code is
-- not also copied into the module interface.

-- | Shared implementation of 'log2Approx' (see there for the meaning of the
-- result).
--
-- The argument is converted with 'toRational'.  A 'Rational' keeps its sign in
-- the numerator, so checking the numerator alone rejects every argument
-- @<= 0@ (via 'error').
--
-- NOTE(review): for floating-point types, 'toRational' maps infinities (and
-- possibly NaN) to finite rationals, so such inputs presumably do not give
-- meaningful results here — confirm whether callers need to care.
{-# INLINE log2Approx_ #-}
log2Approx_ :: (HasCallStack, Real a) => a -> (Int, Ordering)
log2Approx_ x =
  let !xr = toRational x
      !num = numerator xr
      !den = denominator xr
  in  if num <= 0
        -- Report the failure under this function's own name.
        then error "log2Approx_: x <= 0"
        else log2Unchecked num den

-- | Shared implementation of 'log2With': computes the approximation with
-- 'log2Approx_' and hands both components to the rounding function @adj@.
--
-- The approximation is forced to a pair before @adj@ runs, so a non-positive
-- argument raises an error even if @adj@ ignores its inputs.
{-# INLINE log2With_ #-}
log2With_ :: (HasCallStack, Real a) => (Int -> Ordering -> Int) -> a -> Int
log2With_ adj x = case log2Approx_ x of
  (lg, cmp) -> adj lg cmp

-- | Returns an approximate base-2 logarithm of the argument.
--
-- The returned 'Int' is one of the two integers nearest to the exact result.
-- The returned 'Ordering' says how the exact result compares to that 'Int':
-- 'LT' means the exact result is less than the 'Int', 'EQ' means the 'Int' is
-- exact, and 'GT' means the exact result is greater.
--
-- This effectively gives results like "between 7 and 8" and leaves it up to
-- the caller to decide how (or whether) to round.  See 'log2With' for a
-- version that always rounds, using a caller-supplied rounding strategy.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> log2Approx (0.75 :: Double)
-- (-1,GT)
log2Approx :: (HasCallStack, Real a) => a -> (Int, Ordering)
log2Approx x = log2Approx_ x

-- | Returns the base-2 logarithm of the argument with custom rounding.
--
-- The first parameter is a rounding adjustment function.  It receives the
-- approximate result and the 'Ordering' of the exact result relative to it
-- (exactly as 'log2Approx' returns them) and returns the rounded result.
--
-- This is useful when you want something other than floor or ceiling, e.g.
-- round-towards-zero, or rounding to whichever neighbouring integer is even.
--
-- Calls 'error' if the argument is not strictly positive, even when the
-- adjustment function ignores its inputs.
log2With :: (HasCallStack, Real a) => (Int -> Ordering -> Int) -> a -> Int
log2With adj = log2With_ adj

-- | Returns the floor of the base-2 logarithm of a 'Real' argument.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> log2Floor (1/3 :: Rational)
-- -2
log2Floor :: (HasCallStack, Real a) => a -> Int
log2Floor = log2With_ roundDown
 where
  -- 'LT' means the exact logarithm lies below @lg@ (by less than one), so the
  -- floor is @lg - 1@; otherwise @lg@ already is the floor.
  roundDown :: Int -> Ordering -> Int
  roundDown lg LT = lg - 1
  roundDown lg _  = lg

-- | Returns the ceiling of the base-2 logarithm of a 'Real' argument.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> log2Ceiling (1/3 :: Rational)
-- -1
log2Ceiling :: (HasCallStack, Real a) => a -> Int
log2Ceiling = log2With_ roundUp
 where
  -- 'GT' means the exact logarithm lies above @lg@ (by less than one), so the
  -- ceiling is @lg + 1@; otherwise @lg@ already is the ceiling.
  roundUp :: Int -> Ordering -> Int
  roundUp lg GT = lg + 1
  roundUp lg _  = lg

-- | Applies @f@ to @x@ if @x@ is strictly positive; otherwise calls 'error'.
{-# INLINE withPositiveInteger #-}
withPositiveInteger :: HasCallStack => (Integer -> r) -> Integer -> r
withPositiveInteger f x
  | x > 0     = f x
  | otherwise = error "withPositiveInteger: x <= 0"

-- | Returns the floor of the base-2 logarithm of an 'Integral' argument.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> ilog2Floor (9 :: Int)
-- 3
{-# INLINABLE ilog2Floor #-}
ilog2Floor :: (HasCallStack, Integral a) => a -> Int
ilog2Floor x =
  let xi = toInteger x
  in  withPositiveInteger integerLog2Floor xi

-- | Returns the ceiling of the base-2 logarithm of an 'Integral' argument.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> ilog2Ceiling (9 :: Int)
-- 4
{-# INLINABLE ilog2Ceiling #-}
ilog2Ceiling :: (HasCallStack, Integral a) => a -> Int
ilog2Ceiling = ilog2With roundUp
 where
  -- 'GT' means the argument is not an exact power of two, so the ceiling is
  -- one more than the floor; otherwise the floor is exact.
  roundUp :: Int -> Ordering -> Int
  roundUp lg GT = lg + 1
  roundUp lg _  = lg

-- | Returns the approximate base-2 logarithm of an 'Integral' argument.
--
-- The 'Int' is the floor of the logarithm.  The 'Ordering' is 'EQ' when the
-- argument is an exact power of two and 'GT' otherwise; it is never 'LT'.
-- It has the same meaning as in 'log2Approx'.
--
-- Calls 'error' if the argument is not strictly positive.
--
-- >>> ilog2Approx (8 :: Int)
-- (3,EQ)
{-# INLINABLE ilog2Approx #-}
ilog2Approx :: (HasCallStack, Integral a) => a -> (Int, Ordering)
ilog2Approx x = withPositiveInteger approx (toInteger x)
 where
  approx :: Integer -> (Int, Ordering)
  approx xi = (lg, cmp)
   where
    lg = integerLog2Floor xi
    -- xi >= 2^lg because lg is the floor, so anything above 2^lg is inexact.
    cmp | xi > bit lg = GT
        | otherwise   = EQ

-- | Returns the base-2 logarithm of an 'Integral' argument with custom
-- rounding.
--
-- The adjustment function receives the result of 'ilog2Approx' (the floor,
-- and an 'Ordering' that is 'EQ' or 'GT') and returns the rounded result.
-- The approximation is forced first, so a non-positive argument raises an
-- error even if the adjustment function ignores its inputs.
{-# INLINABLE ilog2With #-}
ilog2With :: (HasCallStack, Integral a) => (Int -> Ordering -> Int) -> a -> Int
ilog2With adj x = case ilog2Approx x of
  (lg, cmp) -> adj lg cmp