-- | Extra math module.
--
-- ==== Examples
-- >>> import AtCoder.Extra.Math qualified as M
-- >>> import Data.Semigroup (Product(..), Sum(..))
-- >>> getProduct $ M.power (<>) 32 (Product 2)
-- 4294967296
--
-- >>> getProduct $ M.stimes' 32 (Product 2)
-- 4294967296
--
-- >>> getProduct $ M.mtimes' 32 (Product 2)
-- 4294967296
--
-- @since 1.0.0
module AtCoder.Extra.Math
  ( -- * Binary exponential
    power,
    stimes',
    mtimes',
  )
where

import Data.Bits ((.>>.))

-- TODO: add `HasCallStack` and provide an unchecked `unsafePower` variant.

-- | Calculates \(s^n\) with a custom (associative) multiplication operator, using the binary
-- exponentiation technique.
--
-- The internal implementation is taken from `Data.Semigroup.stimes`, but `power` uses strict
-- evaluation and is often much faster.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \gt 0\)
--
-- @since 1.0.0
{-# INLINE power #-}
power :: (a -> a -> a) -> Int -> a -> a
power mul k base
  | k > 0 = lead k base
  | otherwise = errorWithoutStackTrace "AtCoder.Extra.Math.power: positive multiplier expected"
  where
    -- Phase 1: consume the low zero bits of the exponent by repeated squaring.
    -- Nothing has been accumulated yet, so no identity element is needed.
    -- The bang patterns keep every intermediate value evaluated (no thunk chains).
    lead !e !b
      | e == 1 = b
      | even e = lead (e .>>. 1) (b `mul` b)
      | otherwise = accum (e .>>. 1) (b `mul` b) b
    -- Phase 2: @acc@ is the product of the squares selected by the set bits
    -- consumed so far, and @b@ is the current repeated square. Multiplying
    -- @b@ on the left of @acc@ is safe because both are powers of @base@.
    accum !e !b !acc
      | e == 1 = b `mul` acc
      | even e = accum (e .>>. 1) (b `mul` b) acc
      | otherwise = accum (e .>>. 1) (b `mul` b) (b `mul` acc)

-- | Strict `Data.Semigroup.stimes`.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \gt 0\)
--
-- @since 1.0.0
{-# INLINE stimes' #-}
stimes' :: (Semigroup a) => Int -> a -> a
-- Binary exponentiation with the semigroup's own '(<>)' as the multiplication.
stimes' k s = power (<>) k s

-- | Strict `Data.Semigroup.mtimesDefault`.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Constraints
-- - \(n \ge 0\)
--
-- @since 1.0.0
{-# INLINE mtimes' #-}
mtimes' :: (Monoid a) => Int -> a -> a
mtimes' k m
  -- A positive count delegates to the strict binary exponentiation.
  | k > 0 = power (<>) k m
  -- Zero copies of anything is the monoid identity.
  | k == 0 = mempty
  | otherwise = errorWithoutStackTrace "AtCoder.Extra.Math.mtimes': non-negative multiplier expected"