-- | Internal math implementation.
--
-- ==== Example
-- >>> import AtCoder.Internal.Math
-- >>> powMod 10 60 998244353 -- 10^60 mod 998244353
-- 526662729
--
-- >>> isPrime 998244353
-- True
--
-- >>> isPrime 4
-- False
--
-- >>> invGcd 128 37
-- (1,24)
--
-- >>> 24 * 128 `mod` 37 == 1
-- True
--
-- >>> primitiveRoot 2130706433
-- 3
--
-- >>> floorSumUnsigned 8 12 3 5
-- 6
--
-- @since 1.0.0
module AtCoder.Internal.Math
  ( powMod,
    isPrime,
    invGcd,
    primitiveRoot,
    floorSumUnsigned,
  )
where

import AtCoder.Internal.Assert qualified as ACIA
import AtCoder.Internal.Barrett qualified as ACIBT
import Control.Monad.ST (runST)
import Data.Bits ((.<<.), (.>>.))
import Data.Foldable
import Data.Maybe (fromJust)
import Data.Vector.Generic.Mutable qualified as VGM
import Data.Vector.Unboxed qualified as VU
import Data.Vector.Unboxed.Mutable qualified as VUM
import Data.Word (Word64)
import GHC.Stack (HasCallStack)

-- safeMod :: Int -> Int -> Int
-- safeMod = mod

-- | Returns \(x^n \bmod m\).
--
-- The base is first reduced with `mod`, so a negative @x@ is accepted and the result always lies
-- in \([0, m)\). The power is then computed by binary exponentiation, with every multiplication
-- going through a Barrett reduction object built once for @m@.
--
-- ==== Constraints
-- - \(0 \le n\)
-- - \(1 \le m\)
--
-- Both constraints are checked with 'ACIA.runtimeAssert'.
--
-- ==== Complexity
-- - \(O(\log n)\)
--
-- ==== Example
-- >>> let m = 998244353
-- >>> powMod 10 60 m -- 10^60 mod m
-- 526662729
--
-- @since 1.0.0
{-# INLINE powMod #-}
powMod :: (HasCallStack) => Int -> Int -> Int -> Int
powMod x n0 m0
  -- Everything is congruent to 0 modulo 1.
  | m0 == 1 = 0
  | otherwise = fromIntegral $ inner n0 1 $ fromIntegral (x `mod` m0)
  where
    -- Strict binding: the argument check runs before any guard is evaluated.
    !_ = ACIA.runtimeAssert (0 <= n0 && 1 <= m0) $ "AtCoder.Internal.Math.powMod: given invalid `n` or `m`: " ++ show (n0, m0)
    bt = ACIBT.new64 $ fromIntegral m0
    -- @inner n r y@: @r@ is the accumulated result, @y@ the current repeated square of the base
    -- and @n@ the part of the exponent that has not been consumed yet.
    inner :: Int -> Word64 -> Word64 -> Word64
    inner !n !r !y
      | n == 0 = r
      | otherwise =
          let r' = if odd n then ACIBT.mulMod bt r y else r
              y' = ACIBT.mulMod bt y y
           in inner (n .>>. 1) r' y'

-- | Miller–Rabin primality test with the fixed witness set \(\{2, 7, 61\}\).
--
-- See M. Forisek and J. Jancina, /Fast Primality Testing for Integers That Fit into a Machine Word/.
--
-- ==== Constraints
-- - The witness set \(\{2, 7, 61\}\) is only known to be deterministic for
--   \(n \lt 4{,}759{,}123{,}141\).
-- - The squaring step multiplies two residues in plain `Int`, so \((n - 1)^2\) must not overflow
--   an `Int`. Inputs in the signed 32-bit range satisfy both conditions.
--
-- Neither condition is asserted.
--
-- ==== Example
-- >>> isPrime 998244353
-- True
--
-- >>> isPrime 4
-- False
--
-- @since 1.0.0
{-# INLINE isPrime #-}
isPrime :: Int -> Bool
isPrime n
  | n <= 1 = False
  -- The witnesses themselves are prime and must not be tested against themselves.
  | n == 2 || n == 7 || n == 61 = True
  | even n = False
  | otherwise = all passes [2, 7, 61 :: Int]
  where
    -- Odd part of @n - 1@, i.e. @n - 1 = d * 2^s@ with @d@ odd.
    d :: Int
    d = oddPart (n - 1)

    -- Does @n@ behave like a prime with respect to witness @a@?
    passes :: Int -> Bool
    passes a = squareUp d (powMod a d n)

    -- Strips every factor of two from a positive number.
    oddPart :: Int -> Int
    oddPart k
      | even k = oddPart (k `div` 2)
      | otherwise = k

    -- @squareUp t y@ maintains @y = a^t mod n@ while doubling @t@ from @d@ towards @n - 1@.
    -- It stops at the first of: exponent exhausted, @y == 1@, or @y == n - 1@. The witness
    -- passes when the walk ended on @n - 1@, or when it stopped right at the start (@t@ still
    -- odd, which means @a^d@ was already @1@ or @n - 1@).
    squareUp :: Int -> Int -> Bool
    squareUp t y
      | t == n - 1 || y == 1 || y == n - 1 = y == n - 1 || odd t
      | otherwise = squareUp (t .<<. 1) (y * y `mod` n)

-- | Returns \((g, x)\) such that \(g = \gcd(a, b)\), \(x a \equiv g \pmod b\) and
-- \(0 \le x \lt b / g\).
--
-- The first argument is reduced with `mod` before the extended Euclidean algorithm runs, so a
-- negative @a@ is accepted.
--
-- ==== Constraints
-- - \(1 \le b\) (not asserted)
--
-- ==== Example
-- >>> invGcd 128 37
-- (1,24)
--
-- @since 1.0.0
{-# INLINE invGcd #-}
invGcd :: Int -> Int -> (Int, Int)
invGcd a0 b
  -- gcd(0, b) = b, and x = 0 trivially satisfies 0 * a = b (mod b).
  | a == 0 = (b, 0)
  | otherwise = inner b a 0 1
  where
    -- `mod` (not `rem`) keeps the reduced value in [0, b) even for negative input.
    !a = a0 `mod` b
    -- Contracts:
    -- [1] s - m0 * a = 0 (mod b)
    -- [2] t - m1 * a = 0 (mod b)
    -- [3] s * |m1| + t * |m0| <= b
    inner :: Int -> Int -> Int -> Int -> (Int, Int)
    inner !s !t !m0 !m1
      | t == 0 =
          -- By [1], m0 is an inverse-like coefficient for s = gcd; shift it into [0, b / s).
          let !m' = if m0 < 0 then m0 + b `div` s else m0
           in (s, m')
      | otherwise =
          -- One Euclid step: (s, t) <- (t, s mod t), carrying the coefficients along.
          let !u = s `div` t
              !s' = s - t * u
              !m0' = m0 - m1 * u
           in inner t s' m1 m0'

-- | Returns a primitive root modulo \(m\).
--
-- A handful of frequently used moduli are answered from a table. For any other @m@ the distinct
-- prime factors \(p\) of \(m - 1\) are collected by trial division, and the smallest
-- \(g \ge 2\) with \(g^{(m - 1) / p} \not\equiv 1 \pmod m\) for every such \(p\) is returned.
--
-- ==== Constraints
-- - \(m\) must be prime. This is not checked; the test above is only a valid characterisation
--   of primitive roots for prime moduli.
--
-- ==== Example
-- >>> primitiveRoot 2130706433
-- 3
--
-- @since 1.0.0
{-# INLINE primitiveRoot #-}
primitiveRoot :: Int -> Int
primitiveRoot m
  | m == 2 = 1
  | m == 167772161 = 3
  | m == 469762049 = 3
  | m == 754974721 = 11
  | m == 998244353 = 3
  | otherwise = runST $ do
      -- Distinct prime factors of @m - 1@.
      let divs_ = VU.create $ do
            -- 20 slots are plenty: the product of the 16 smallest primes already exceeds 2^63,
            -- so a 64-bit `Int` has at most 15 distinct prime factors.
            divs <- VUM.replicate 20 (0 :: Int)
            -- 2 is always recorded first (@m - 1@ is even for an odd prime @m@).
            VGM.write divs 0 2
            -- Removes every remaining factor of two.
            let stripTwos x
                  | even x = stripTwos $ x `div` 2
                  | otherwise = x
            -- Trial division by odd candidates @i@; @x@ is the part not yet factored and @cnt@
            -- the number of primes stored so far.
            let factor !i !x !cnt
                  -- The square is taken in `Word64` so that it cannot overflow `Int`.
                  | (fromIntegral i :: Word64) * fromIntegral i > fromIntegral x = pure (x, cnt)
                  | x `mod` i == 0 = do
                      VGM.write divs cnt i
                      let stripFactor x'
                            | x' `mod` i == 0 = stripFactor (x' `div` i)
                            | otherwise = x'
                      factor (i + 2) (stripFactor x) (cnt + 1)
                  | otherwise = factor (i + 2) x cnt
            (!x, !cnt) <- factor 3 (stripTwos ((m - 1) `div` 2)) 1
            -- Whatever is left above 1 is itself a prime factor.
            !cnt' <- do
              if x > 1
                then do
                  VGM.write divs cnt x
                  pure $ cnt + 1
                else pure cnt
            pure $ VUM.take cnt' divs
      -- @g@ qualifies when no proper quotient exponent @(m - 1) / p@ maps it to 1.
      let test g = VU.all (testG g) divs_
          testG g divsI = powMod g ((m - 1) `div` divsI) m /= 1
      -- The candidate list is infinite, so `find` either yields `Just` or does not terminate.
      pure . fromJust $ find test [2 ..]

-- | Returns \(\sum\limits_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor\).
--
-- The arguments are given in the order @n@, @m@, @a@, @b@. The sum is accumulated in `Int`
-- arithmetic without any overflow check.
--
-- ==== Constraints
-- - \(n \lt 2^{32}\)
-- - \(1 \le m \lt 2^{32}\)
-- - @a@ and @b@ are expected to be non-negative; negative values are not normalised here.
--
-- ==== Complexity
-- - \(O(\log m)\)
--
-- ==== Example
-- >>> floorSumUnsigned 8 12 3 5
-- 6
--
-- @since 1.0.0
{-# INLINE floorSumUnsigned #-}
floorSumUnsigned :: Int -> Int -> Int -> Int -> Int
floorSumUnsigned = go 0
  where
    -- @go acc n m a b@ performs one Euclid-like reduction step: the whole multiples of @m@
    -- contained in @a@ and @b@ are summed in closed form, then the remaining problem is
    -- restated with the roles of @m@ and the reduced @a@ swapped.
    go :: Int -> Int -> Int -> Int -> Int -> Int
    go acc n m a b
      -- Every remaining term is below @m@, so all remaining floors are zero.
      | yMax < m = acc'
      | otherwise = go acc' (yMax `div` m) a' m (yMax `rem` m)
      where
        -- @a@ and @b@ reduced below @m@.
        a'
          | a >= m = a `rem` m
          | otherwise = a
        b'
          | b >= m = b `rem` m
          | otherwise = b
        -- Contribution of floor(a / m) * i summed over i = 0 .. n - 1.
        da
          | a >= m = n * (n - 1) `div` 2 * (a `div` m)
          | otherwise = 0
        -- Contribution of floor(b / m), once per term.
        db
          | b >= m = n * (b `div` m)
          | otherwise = 0
        acc' = acc + da + db
        yMax = a' * n + b'