-- |
-- Module:      Math.NumberTheory.Utils
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Daniel Fischer <daniel.is.fischer@googlemail.com>
--
-- Some utilities, mostly for bit twiddling.
--

{-# LANGUAGE BangPatterns   #-}
{-# LANGUAGE MagicHash      #-}
{-# LANGUAGE UnboxedTuples  #-}
{-# LANGUAGE RankNTypes     #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE GADTs          #-}

module Math.NumberTheory.Utils
    ( SomeKnown(..)
    , shiftToOddCount
    , shiftToOdd
    , shiftToOdd#
    , shiftToOddCount#
    , shiftToOddCountBigNat
    , splitOff
    , splitOff#

    , mergeBy

    , recipMod

    , toWheel30
    , fromWheel30
    , withSomeKnown
    , intVal
    ) where

import Prelude hiding (mod, quotRem)
import qualified Prelude as P

import Data.Bits
import Data.Euclidean
import Data.Semiring (Semiring(..), isZero)
import GHC.Base
import GHC.Integer.GMP.Internals
import qualified Math.NumberTheory.Utils.FromIntegral as UT
import GHC.Natural
import GHC.TypeNats
import Math.NumberTheory.Utils.FromIntegral (intToWord)

-- | Remove factors of @2@ and count them. If
--   @n = 2^k*m@ with @m@ odd, the result is @(k, m)@.
--   Precondition: argument not @0@ (not checked).
--
--   The generic definition round-trips through 'Integer'; the rewrite
--   rules below substitute specialised implementations for 'Int', 'Word',
--   'Integer' and 'Natural'.
{-# RULES
"shiftToOddCount/Int"       shiftToOddCount = shiftOCInt
"shiftToOddCount/Word"      shiftToOddCount = shiftOCWord
"shiftToOddCount/Integer"   shiftToOddCount = shiftOCInteger
"shiftToOddCount/Natural"   shiftToOddCount = shiftOCNatural
  #-}
{-# INLINE [1] shiftToOddCount #-}
shiftToOddCount :: Integral a => a -> (Word, a)
shiftToOddCount n = fromInteger <$> shiftOCInteger (toInteger n)

-- | Specialised version of 'shiftToOddCount' for @'Word'@.
--   Precondition: argument strictly positive (not checked).
shiftOCWord :: Word -> (Word, Word)
shiftOCWord (W# w#) =
    let !tz# = ctz# w#
    in  (W# tz#, W# (w# `uncheckedShiftRL#` word2Int# tz#))

-- | Specialised version of 'shiftToOddCount' for @'Int'@.
--   Precondition: argument nonzero (not checked).
--
--   The odd part keeps the sign of the argument, e.g.
--   @shiftOCInt (-12) == (2, -3)@.
shiftOCInt :: Int -> (Word, Int)
shiftOCInt (I# i#) =
    -- The trailing-zero count of the two's complement bits is the 2-adic
    -- valuation of the argument; an arithmetic shift then preserves the
    -- sign (a logical shift would turn negative arguments into large
    -- positive numbers).
    case ctz# (int2Word# i#) of
      z# -> (W# z#, I# (i# `uncheckedIShiftRA#` word2Int# z#))

-- | Specialised version of 'shiftToOddCount' for @'Integer'@.
--   Precondition: argument nonzero (not checked).
--
--   Negative arguments keep their sign, e.g.
--   @shiftOCInteger (-12) == (2, -3)@.
shiftOCInteger :: Integer -> (Word, Integer)
shiftOCInteger n@(S# i#) =
    -- Arithmetic shift: a logical shift of the two's complement bits would
    -- produce a large positive result for a negative i#. The shifted value
    -- never grows in magnitude, so it still fits in S#.
    case ctz# (int2Word# i#) of
      0## -> (0, n)
      z#  -> (W# z#, S# (i# `uncheckedIShiftRA#` word2Int# z#))
shiftOCInteger n@(Jp# bn#) =
    case bigNatZeroCount bn# of
      0## -> (0, n)
      z#  -> (W# z#, bigNatToInteger (bn# `shiftRBigNat` word2Int# z#))
shiftOCInteger n@(Jn# bn#) =
    case bigNatZeroCount bn# of
      0## -> (0, n)
      z#  -> (W# z#, bigNatToNegInteger (bn# `shiftRBigNat` word2Int# z#))

-- | Specialised version of 'shiftToOddCount' for @'Natural'@.
--   Precondition: argument nonzero (not checked).
shiftOCNatural :: Natural -> (Word, Natural)
shiftOCNatural n@(NatS# w#) =
    case ctz# w# of
      0## -> (0, n)
      z#  -> (W# z#, NatS# (w# `uncheckedShiftRL#` word2Int# z#))
shiftOCNatural n@(NatJ# bn#) =
    case bigNatZeroCount bn# of
      0## -> (0, n)
      z#  -> (W# z#, bigNatToNatural (bn# `shiftRBigNat` word2Int# z#))

-- | Remove factors of @2@ from a @'BigNat'@ and count them, in the manner
--   of 'shiftToOddCount'. When there is nothing to remove, the argument is
--   returned unchanged.
--   NOTE(review): presumably the argument must be nonzero, since
--   'bigNatZeroCount' requires that.
shiftToOddCountBigNat :: BigNat -> (Word, BigNat)
shiftToOddCountBigNat bn# =
    let z# = bigNatZeroCount bn#
    in  if isTrue# (z# `eqWord#` 0##)
          then (0, bn#)
          else (W# z#, bn# `shiftRBigNat` word2Int# z#)

-- | Count trailing zero bits in a @'BigNat'@.
--   Precondition: argument nonzero (not checked, Integer invariant).
--
--   Scans limbs with 'indexBigNat#' from index 0 upwards. Every zero limb
--   contributes a full word of zero bits, and the first nonzero limb
--   contributes its own trailing-zero count.
bigNatZeroCount :: BigNat -> Word#
bigNatZeroCount bn# = go 0#
  where
    !(W# limbBits#) = intToWord (finiteBitSize (0 :: Word))

    -- i# is the index of the limb currently inspected; all limbs below it
    -- were zero.
    go :: Int# -> Word#
    go i# = case indexBigNat# bn# i# of
      0## -> go (i# +# 1#)
      w#  -> (int2Word# i# `timesWord#` limbBits#) `plusWord#` ctz# w#

-- | Remove factors of @2@. If @n = 2^k*m@ with @m@ odd, the result is @m@.
--   Precondition: argument not @0@ (not checked).
--
--   The generic definition goes through 'Integer'; the rewrite rules below
--   substitute specialised versions for 'Int', 'Word' and 'Integer'.
{-# RULES
"shiftToOdd/Int"       shiftToOdd = shiftOInt
"shiftToOdd/Word"      shiftToOdd = shiftOWord
"shiftToOdd/Integer"   shiftToOdd = shiftOInteger
  #-}
{-# INLINE [1] shiftToOdd #-}
shiftToOdd :: Integral a => a -> a
shiftToOdd n = fromInteger $ shiftOInteger $ toInteger n

-- | Specialised version of 'shiftToOdd' for @'Int'@.
--   Precondition: argument nonzero (not checked).
--
--   Uses an arithmetic right shift so that negative arguments keep their
--   sign, e.g. @shiftOInt (-12) == -3@. A logical shift would turn them
--   into large positive numbers.
shiftOInt :: Int -> Int
shiftOInt (I# i#) = I# (i# `uncheckedIShiftRA#` word2Int# (ctz# (int2Word# i#)))

-- | Specialised version of 'shiftToOdd' for @'Word'@.
--   Precondition: argument nonzero (not checked).
shiftOWord :: Word -> Word
shiftOWord (W# w#) = W# (w# `uncheckedShiftRL#` word2Int# (ctz# w#))

-- | Specialised version of 'shiftToOdd' for @'Integer'@.
--   Precondition: argument nonzero (not checked).
--
--   Negative arguments keep their sign, e.g. @shiftOInteger (-12) == -3@.
shiftOInteger :: Integer -> Integer
shiftOInteger (S# i#) =
    -- Arithmetic shift: a logical shift of the two's complement bits would
    -- produce a large positive number for a negative i#. The result never
    -- grows in magnitude, so it still fits in S#.
    S# (i# `uncheckedIShiftRA#` word2Int# (ctz# (int2Word# i#)))
shiftOInteger n@(Jp# bn#) =
    case bigNatZeroCount bn# of
      0## -> n
      z#  -> bigNatToInteger (bn# `shiftRBigNat` word2Int# z#)
shiftOInteger n@(Jn# bn#) =
    case bigNatZeroCount bn# of
      0## -> n
      z#  -> bigNatToNegInteger (bn# `shiftRBigNat` word2Int# z#)

-- | Shift argument right until the result is odd.
--   Precondition: argument not @0@, not checked.
shiftToOdd# :: Word# -> Word#
shiftToOdd# w# =
    let !places# = word2Int# (ctz# w#)
    in  uncheckedShiftRL# w# places#

-- | Like @'shiftToOdd#'@, but also return the number of places shifted:
--   the result is @(# k, m #)@ where the argument equals @2^k * m@.
--   Precondition: argument not @0@ (not checked), as for @'shiftToOdd#'@.
shiftToOddCount# :: Word# -> (# Word#, Word# #)
shiftToOddCount# w# = (# k#, w# `uncheckedShiftRL#` word2Int# k# #)
  where
    !k# = ctz# w#

-- | @'splitOff' p n@ divides @n@ by @p@ for as long as 'divide' succeeds.
--   It returns the number of successful divisions together with the
--   remaining value.
--   For @n = 0@ the result is @(0, 0)@, because dividing zero would
--   otherwise never stop.
--
--   NOTE(review): if @m `divide` p@ keeps succeeding, as it presumably
--   does for a unit @p@ such as @1@, the loop never terminates.
splitOff :: (Eq a, GcdDomain a) => a -> a -> (Word, a)
splitOff p n
  | isZero n  = (0, zero) -- prevent infinite loop
  | otherwise = go 0 n
  where
    -- The counter is strict so no chain of (+ 1) thunks builds up.
    go !k m = maybe (k, m) (go (k + 1)) (m `divide` p)
{-# INLINABLE splitOff #-}

-- | Unboxed counterpart of 'splitOff' for machine words. It divides @n@ by
--   @p@ while the remainder is zero and returns @(# k, m #)@, where @k@ is
--   the number of divisions and @m@ what is left. For @n = 0@ it returns
--   @(# 0, 0 #)@.
--
--   It is difficult to convince GHC to unbox the output of 'splitOff' and
--   its inner loop, so we fall back to this specialised unboxed version to
--   minimise allocations.
--
--   Not checked: @p = 0@ divides by zero, and @p = 1@ never terminates.
splitOff# :: Word# -> Word# -> (# Word#, Word# #)
splitOff# p n
  | isTrue# (n `eqWord#` 0##) = (# 0##, 0## #)
  | otherwise                = go 0## n
  where
    go k m = case m `quotRemWord#` p of
      (# q, 0## #) -> go (k `plusWord#` 1##) q
      _            -> (# k, m #)
{-# INLINABLE splitOff# #-}

-- | Merge two lists that are each ordered according to @cmp@ into one
--   ordered list. When two elements compare equal, the one from the first
--   list is emitted first. Checks neither its precondition nor its
--   postcondition.
mergeBy :: (a -> a -> Ordering) -> [a] -> [a] -> [a]
mergeBy cmp = go
  where
    go [] ys = ys
    go xs [] = xs
    go xs@(x : xs') ys@(y : ys')
      | cmp x y == GT = y : go xs ys'
      | otherwise     = x : go xs' ys

-- | @'recipMod' x m@ reduces @x@ modulo @m@ and passes it to
--   'recipModInteger'. A result of @0@ is mapped to 'Nothing'; presumably
--   that is how 'recipModInteger' reports that no inverse exists.
--   Any other result @y@ becomes @'Just' y@.
--
--   Work around https://ghc.haskell.org/trac/ghc/ticket/14085
recipMod :: Integer -> Integer -> Maybe Integer
recipMod x m
  | y == 0    = Nothing
  | otherwise = Just y
  where
    y = recipModInteger (x `P.mod` m) m

-- | Convert a @'BigNat'@ to a @'Natural'@. A single-limb 'BigNat'
--   ('sizeofBigNat#' equal to 1) becomes 'NatS#' of that word; anything
--   larger is wrapped in 'NatJ#'.
bigNatToNatural :: BigNat -> Natural
bigNatToNatural bn = case sizeofBigNat# bn of
  1# -> NatS# (bigNatToWord bn)
  _  -> NatJ# bn

-------------------------------------------------------------------------------
-- Helpers for mapping to rough numbers and back.
-- Copied from Data.BitStream.WheelMapping.

-- | Map a number coprime to @30@ to its position among such numbers:
--   @1, 7, 11, 13, 17, 19, 23, 29, 31, 37, ...@ become
--   @0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...@.
toWheel30 :: (Integral a, Bits a) => a -> a
toWheel30 i = wheelStart + spokeIndex
  where
    (q, r)     = i `P.quotRem` 30
    -- Each block of 30 contains eight numbers coprime to 30.
    wheelStart = q `shiftL` 3
    -- Sends the residues 1, 7, 11, 13, 17, 19, 23, 29 to 0 .. 7.
    spokeIndex = (r + r `shiftR` 4) `shiftR` 2

-- | Map a position back to the corresponding number coprime to @30@:
--   @0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...@ become
--   @1, 7, 11, 13, 17, 19, 23, 29, 31, 37, ...@.
fromWheel30 :: (Num a, Bits a) => a -> a
fromWheel30 i = oddPart + correction
  where
    oddPart    = (i `shiftL` 2 - i `shiftR` 2) .|. 1
    correction = (i `shiftL` 1 - i `shiftR` 1) .&. 2

-------------------------------------------------------------------------------
-- Helpers for dealing with data types parametrised by natural numbers.

-- | Existential wrapper around an @f k@ that hides the type-level natural
--   @k@ while keeping its 'KnownNat' dictionary available.
data SomeKnown :: (Nat -> Type) -> Type where
  SomeKnown :: KnownNat k => f k -> SomeKnown f

-- | Unwrap a 'SomeKnown' and apply a function that works for every
--   type-level natural with a 'KnownNat' instance.
withSomeKnown :: (forall k. KnownNat k => f k -> a) -> SomeKnown f -> a
withSomeKnown cont sk = case sk of
  SomeKnown x -> cont x

-- | The type-level natural index @k@ of an @a k@, read with 'natVal' and
--   converted to 'Int' via 'UT.naturalToInt'.
intVal :: KnownNat k => a k -> Int
intVal x = UT.naturalToInt (natVal x)