-- |
-- Module:      Math.NumberTheory.Utils
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Daniel Fischer <daniel.is.fischer@googlemail.com>
--
-- Some utilities, mostly for bit twiddling.
--

{-# LANGUAGE BangPatterns   #-}
{-# LANGUAGE MagicHash      #-}
{-# LANGUAGE UnboxedTuples  #-}
{-# LANGUAGE RankNTypes     #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds      #-}
{-# LANGUAGE GADTs          #-}

module Math.NumberTheory.Utils
    ( SomeKnown(..)
    , shiftToOddCount
    , shiftToOdd
    , shiftToOdd#
    , shiftToOddCount#
    , shiftToOddCountBigNat
    , splitOff
    , splitOff#

    , mergeBy

    , recipMod

    , toWheel30
    , fromWheel30
    , withSomeKnown
    , intVal
    ) where

import Prelude hiding (mod, quotRem)
import qualified Prelude as P

import Data.Bits
import Data.Euclidean
import Data.List.Infinite (Infinite(..))
import Data.Semiring (Semiring(..), isZero)
import GHC.Base
import GHC.Num.BigNat
import GHC.Num.Integer
import GHC.Num.Natural
import qualified Math.NumberTheory.Utils.FromIntegral as UT
import GHC.Natural
import GHC.TypeNats
import Math.NumberTheory.Utils.FromIntegral (intToWord)

-- | Remove factors of @2@ and count them. If
--   @n = 2^k*m@ with @m@ odd, the result is @(k, m)@.
--   Precondition: argument not @0@ (not checked).
--
--   The generic definition round-trips through 'Integer'. The rewrite rules
--   below substitute specialised workers for 'Int', 'Word', 'Integer' and
--   'Natural'. Inlining is delayed until phase 1 so those rules get a chance
--   to fire on the un-inlined name first.
{-# RULES
"shiftToOddCount/Int"       shiftToOddCount = shiftOCInt
"shiftToOddCount/Word"      shiftToOddCount = shiftOCWord
"shiftToOddCount/Integer"   shiftToOddCount = shiftOCInteger
"shiftToOddCount/Natural"   shiftToOddCount = shiftOCNatural
  #-}
{-# INLINE [1] shiftToOddCount #-}
shiftToOddCount :: Integral a => a -> (Word, a)
shiftToOddCount n =
    let !(twos, rest) = shiftOCInteger (toInteger n)
    in  (twos, fromInteger rest)

-- | Specialised version of 'shiftToOddCount' for @'Word'@.
--   Precondition: argument strictly positive (not checked).
shiftOCWord :: Word -> (Word, Word)
shiftOCWord (W# w#) =
    let !(# k#, m# #) = shiftToOddCount# w#
    in  (W# k#, W# m#)

-- | Specialised version of 'shiftToOddCount' for @'Int'@.
--   Precondition: argument nonzero (not checked).
--
--   The trailing-zero count is read from the two's-complement bit pattern.
--   That pattern has the same number of trailing zeros as the absolute value.
--   The odd part is then produced with an /arithmetic/ right shift, so a
--   negative argument keeps its sign (@-12 -> (2, -3)@). A logical shift
--   would turn it into a huge positive number instead.
shiftOCInt :: Int -> (Word, Int)
shiftOCInt (I# i#) = case ctz# (int2Word# i#) of
    k# -> (W# k#, I# (uncheckedIShiftRA# i# (word2Int# k#)))

-- | Specialised version of 'shiftToOddCount' for @'Integer'@.
--   Precondition: argument nonzero (not checked).
--
--   The sign of the argument is preserved: @-12 -> (2, -3)@.
--   An argument that is already odd is returned unchanged.
shiftOCInteger :: Integer -> (Word, Integer)
shiftOCInteger n@(IS i#) =
    -- Count trailing zeros on the two's-complement pattern (same count as
    -- for the absolute value), then shift /arithmetically/ so that negative
    -- small integers stay negative. The shifted value always fits in an
    -- 'Int', so rebuilding it with 'IS' respects the Integer invariant.
    case ctz# (int2Word# i#) of
      0## -> (0, n)
      z#  -> (W# z#, IS (uncheckedIShiftRA# i# (word2Int# z#)))
shiftOCInteger n@(IP bn#) = case bigNatZeroCount bn# of
    0## -> (0, n)
    z#  -> (W# z#, integerFromBigNat# (bn# `bigNatShiftR#` z#))
shiftOCInteger n@(IN bn#) = case bigNatZeroCount bn# of
    0## -> (0, n)
    z#  -> (W# z#, integerFromBigNatNeg# (bn# `bigNatShiftR#` z#))

-- | Specialised version of 'shiftToOddCount' for @'Natural'@.
--   Precondition: argument nonzero (not checked).
--   An argument that is already odd is returned unchanged.
shiftOCNatural :: Natural -> (Word, Natural)
shiftOCNatural n = case n of
    NatS# w# -> case shiftToOddCount# w# of
      (# 0##, _ #)  -> (0, n)
      (# k#, m# #)  -> (W# k#, NatS# m#)
    NatJ# (BN# bn#) -> case bigNatZeroCount bn# of
      0## -> (0, n)
      k#  -> (W# k#, naturalFromBigNat# (bigNatShiftR# bn# k#))

-- | Like 'shiftToOddCount', but operating directly on an unboxed
--   @'BigNat#'@. It returns the number of trailing zero bits together with
--   the bignat shifted right by that amount. When there are no trailing
--   zeros, the input array itself is returned.
--   Precondition (presumably): argument nonzero. 'bigNatZeroCount' keeps
--   scanning until it finds a nonzero limb.
shiftToOddCountBigNat :: BigNat# -> (# Word, BigNat# #)
shiftToOddCountBigNat bn# =
    let !tz# = bigNatZeroCount bn#
    in  case tz# of
          0## -> (# 0, bn# #)
          _   -> (# W# tz#, bigNatShiftR# bn# tz# #)

-- | Count trailing zero bits in a @'BigNat#'@.
--   Precondition: argument nonzero (not checked, Integer invariant).
--
--   Limbs are scanned from index 0 upwards. This assumes limb 0 is the least
--   significant, as in ghc-bignum. Each all-zero limb adds a full word's
--   worth of bits, and the first nonzero limb adds its own trailing-zero
--   count. Nothing here stops the scan at the end of the array, so a zero
--   argument must never reach this function.
bigNatZeroCount :: BigNat# -> Word#
bigNatZeroCount bn# = scan 0## 0#
  where
    !(W# wordBits#) = intToWord (finiteBitSize (0 :: Word))

    scan :: Word# -> Int# -> Word#
    scan acc# ix# = case bigNatIndex# bn# ix# of
      0##   -> scan (acc# `plusWord#` wordBits#) (ix# +# 1#)
      limb# -> acc# `plusWord#` ctz# limb#

-- | Remove factors of @2@. If @n = 2^k*m@ with @m@ odd, the result is @m@.
--   Precondition: argument not @0@ (not checked).
--
--   The generic definition goes through 'Integer'. The rules below swap in
--   specialised workers for 'Int', 'Word' and 'Integer'. Inlining is delayed
--   to phase 1 so the rules can fire first.
{-# RULES
"shiftToOdd/Int"       shiftToOdd = shiftOInt
"shiftToOdd/Word"      shiftToOdd = shiftOWord
"shiftToOdd/Integer"   shiftToOdd = shiftOInteger
  #-}
{-# INLINE [1] shiftToOdd #-}
shiftToOdd :: Integral a => a -> a
shiftToOdd n = fromInteger $ shiftOInteger $ toInteger n

-- | Specialised version of 'shiftToOdd' for @'Int'@.
--   Precondition: argument nonzero (not checked).
--
--   Uses an /arithmetic/ right shift so negative arguments keep their sign
--   (@-12 -> -3@). The two's-complement bit pattern has the same number of
--   trailing zeros as the absolute value, so 'ctz#' on it gives the correct
--   shift amount.
shiftOInt :: Int -> Int
shiftOInt (I# i#) =
    I# (uncheckedIShiftRA# i# (word2Int# (ctz# (int2Word# i#))))

-- | Specialised version of 'shiftToOdd' for @'Word'@.
--   Precondition: argument nonzero (not checked).
shiftOWord :: Word -> Word
shiftOWord (W# w#) = W# m#
  where
    !m# = shiftToOdd# w#

-- | Specialised version of 'shiftToOdd' for @'Integer'@.
--   Precondition: argument nonzero (not checked).
--
--   The sign of the argument is preserved (@-12 -> -3@).
--   An argument that is already odd is returned unchanged.
shiftOInteger :: Integer -> Integer
shiftOInteger (IS i#) =
    -- An arithmetic shift keeps negative small integers negative. A logical
    -- shift of the two's-complement pattern would yield a large positive
    -- value. The result fits in an 'Int', so 'IS' is a valid representation.
    IS (uncheckedIShiftRA# i# (word2Int# (ctz# (int2Word# i#))))
shiftOInteger n@(IP bn#) = case bigNatZeroCount bn# of
    0## -> n
    z#  -> integerFromBigNat# (bn# `bigNatShiftR#` z#)
shiftOInteger n@(IN bn#) = case bigNatZeroCount bn# of
    0## -> n
    z#  -> integerFromBigNatNeg# (bn# `bigNatShiftR#` z#)

-- | Shift argument right until the result is odd.
--   Precondition: argument not @0@, not checked.
--   The shift is logical, so the argument is treated as an unsigned word.
shiftToOdd# :: Word# -> Word#
shiftToOdd# w# =
    let !tz# = word2Int# (ctz# w#)
    in  uncheckedShiftRL# w# tz#

-- | Like @'shiftToOdd#'@, but count the number of places to shift too.
--   Returns @(# k, m #)@, where @k@ is the number of trailing zero bits and
--   @m@ is the argument logically shifted right by @k@.
shiftToOddCount# :: Word# -> (# Word#, Word# #)
shiftToOddCount# w# = (# tz#, uncheckedShiftRL# w# (word2Int# tz#) #)
  where
    !tz# = ctz# w#

-- | Remove all factors of @p@ from @n@ and count them. The result is
--   @(k, m)@, where @n = p^k * m@ and @p@ does not divide @m@.
--
--   A zero @n@ short-circuits to @(0, 0)@. Zero is divisible by everything,
--   so the loop would otherwise never end.
--   NOTE(review): a unit @p@ (e.g. @1@) also divides every @m@, so the loop
--   does not terminate for such @p@. Callers must not pass one.
splitOff :: (Eq a, GcdDomain a) => a -> a -> (Word, a)
splitOff p n
  | isZero n  = (0, zero) -- prevent infinite loop
  | otherwise = loop 0 n
  where
    loop !count m = maybe (count, m) (loop (count + 1)) (m `divide` p)
{-# INLINABLE splitOff #-}

-- | Unboxed counterpart of 'splitOff' for machine words. It returns
--   @(# k, m #)@, where @n = p^k * m@ and @m@ is not divisible by @p@.
--
--   It is difficult to convince GHC to unbox the output of 'splitOff' and its
--   inner loop, so we fall back to this specialised unboxed version to
--   minimise allocations.
--
--   @n = 0@ yields @(# 0, 0 #)@.
--   NOTE(review): @p = 1@ never terminates, since every @m@ divides evenly.
--   @p = 0@ presumably crashes with a division by zero. Callers must avoid
--   both.
splitOff# :: Word# -> Word# -> (# Word#, Word# #)
splitOff# p n = case n of
    0## -> (# 0##, 0## #)
    _   -> loop 0## n
  where
    loop :: Word# -> Word# -> (# Word#, Word# #)
    loop k m = case quotRemWord# m p of
      (# q, 0## #) -> loop (plusWord# k 1##) q
      _            -> (# k, m #)
{-# INLINABLE splitOff# #-}

-- | Merge two ordered infinite lists into an ordered infinite list.
--   When @cmp@ reports the heads as equal, the head of the first list is
--   emitted first. Both elements are kept; there is no deduplication.
--   Checks neither its precondition nor its postcondition.
mergeBy :: (a -> a -> Ordering) -> Infinite a -> Infinite a -> Infinite a
mergeBy cmp = go
  where
    go xxs@(x :< xs) yys@(y :< ys)
      | cmp x y == GT = y :< go xxs ys
      | otherwise     = x :< go xs yys

-- | Modular inverse: @'Just' y@ with @x * y == 1 (mod m)@ when @x@ is
--   invertible modulo @m@, and 'Nothing' otherwise.
--   Works around https://ghc.haskell.org/trac/ghc/ticket/14085
--   NOTE(review): @x `mod` m@ is computed first and @m@ is converted to
--   'Natural' via 'fromInteger'. A zero or negative @m@ therefore presumably
--   raises an arithmetic exception rather than returning 'Nothing'.
recipMod :: Integer -> Integer -> Maybe Integer
recipMod x m =
    let reduced = x `P.mod` m
        modulus = fromInteger m
    in  case integerRecipMod# reduced modulus of
          (# y | #)  -> Just (toInteger y)
          (# | () #) -> Nothing

-------------------------------------------------------------------------------
-- Helpers for mapping to rough numbers and back.
-- Copied from Data.BitStream.WheelMapping.

-- | Map a number coprime to 30 to its index among such numbers:
--   1, 7, 11, 13, 17, 19, 23, 29, 31, ... map to 0, 1, 2, 3, 4, 5, 6, 7, 8, ...
--   Each block of 30 holds eight such residues. The result is therefore
--   @8 * q@ plus the position of the residue @r@ within its block.
toWheel30 :: (Integral a, Bits a) => a -> a
toWheel30 i = blockStart + residueIndex
  where
    (q, r)       = i `P.quotRem` 30
    blockStart   = q `shiftL` 3
    residueIndex = (r + r `shiftR` 4) `shiftR` 2

-- | Inverse of 'toWheel30': map an index back to the corresponding number
--   coprime to 30, so 0, 1, 2, ... map to 1, 7, 11, 13, 17, 19, 23, 29, 31, ...
fromWheel30 :: (Num a, Bits a) => a -> a
fromWheel30 i = coarse + correction
  where
    coarse     = (i `shiftL` 2 - i `shiftR` 2) .|. 1
    correction = (i `shiftL` 1 - i `shiftR` 1) .&. 2

-------------------------------------------------------------------------------
-- Helpers for dealing with data types parametrised by natural numbers.

-- | Existential wrapper around a value of type @f k@. The type-level
--   natural index @k@ is hidden, but its 'KnownNat' evidence is kept, so it
--   can be recovered by pattern matching on the constructor.
data SomeKnown (f :: Nat -> Type) where
  SomeKnown :: KnownNat k => f k -> SomeKnown f

-- | Eliminate a 'SomeKnown'. The wrapped value, together with its
--   'KnownNat' evidence, is handed to a continuation that must work for
--   every index.
withSomeKnown :: (forall k. KnownNat k => f k -> a) -> SomeKnown f -> a
withSomeKnown f sk = case sk of
    SomeKnown x -> f x

-- | The type-level natural index of a value, as an 'Int'.
--   The 'Natural' from 'natVal' is converted with 'UT.naturalToInt'.
intVal :: KnownNat k => a k -> Int
intVal proxy = UT.naturalToInt (natVal proxy)