{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Mod.Word
( Mod
, unMod
, invertMod
, (^%)
) where
import Prelude as P hiding (even)
import Control.Exception
import Control.DeepSeq
import Data.Bits
#ifdef MIN_VERSION_semirings
import Data.Euclidean (GcdDomain(..), Euclidean(..), Field)
import Data.Ratio
import Data.Semiring (Semiring(..), Ring(..))
#endif
#ifdef MIN_VERSION_vector
import Data.Primitive (Prim)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as M
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
#endif
import Foreign.Storable (Storable)
import GHC.Exts
import GHC.Generics
import GHC.Integer.GMP.Internals
import GHC.Natural (Natural(..))
import GHC.TypeNats (Nat, KnownNat, natVal)
-- | Residues modulo the type-level natural @m@, stored in a single machine
-- 'Word'.
--
-- The data constructor is not exported (only the type and 'unMod' are), and
-- the arithmetic in this module keeps the stored 'Word' reduced below @m@ for
-- @m > 0@. Moduli that do not fit into a 'Word' are rejected at use sites.
newtype Mod (m :: Nat) = Mod
  { unMod :: Word
  -- ^ The residue as a plain 'Word'.
  }
#ifdef MIN_VERSION_vector
  deriving (Eq, Ord, Generic, Storable, Prim)
#else
  deriving (Eq, Ord, Generic, Storable)
#endif
-- | Forcing a 'Mod' to normal form forces its underlying 'Word'.
instance NFData (Mod m) where
  rnf (Mod w) = rnf w
-- | Renders the residue together with its modulus, in the form
-- @(x `modulo` m)@.
instance KnownNat m => Show (Mod m) where
  show mx = concat ["(", show (unMod mx), " `modulo` ", show (natVal mx), ")"]
-- | Enumeration over the residues @0 .. m-1@.
--
-- 'succ' throws 'Overflow' at 'maxBound' and 'pred' throws 'Underflow' at
-- 'minBound'. Open-ended enumerations stop at 'maxBound' when ascending and
-- at 'minBound' when descending. 'toEnum' reduces its argument via
-- 'fromIntegral'.
instance KnownNat m => Enum (Mod m) where
  succ x
    | x == maxBound = throw Overflow
    | otherwise = coerce (succ @Word) x
  pred x
    | x == minBound = throw Underflow
    | otherwise = coerce (pred @Word) x
  toEnum n = fromIntegral n
  fromEnum (Mod w) = fromIntegral w
  enumFrom lo = enumFromTo lo maxBound
  enumFromThen lo next = enumFromThenTo lo next limit
    where
      limit
        | next >= lo = maxBound
        | otherwise = minBound
  enumFromTo = coerce (enumFromTo @Word)
  enumFromThenTo = coerce (enumFromThenTo @Word)
-- | Residues range over @[0, m-1]@; the modulus is read from the type.
instance KnownNat m => Bounded (Mod m) where
  minBound = Mod 0
  maxBound = top
    where
      top = Mod (fromIntegral (natVal top) - 1)
#if !MIN_VERSION_base(4,12,0)
-- | Fallback for base < 4.12: add two words, returning the low word of the
-- sum first and the carry (the high word of 'plusWord2#') as an 'Int#'.
addWordC# :: Word# -> Word# -> (# Word#, Int# #)
addWordC# a# b# = case plusWord2# a# b# of
  (# hi#, lo# #) -> (# lo#, word2Int# hi# #)
#endif
-- | Modular addition. Assumes both operands are already below the modulus.
--
-- Subtracts @m@ once if the sum carried out of the word or is @>= m@; after
-- a carry, the word-level wrap-around makes that subtraction exact.
addMod :: Natural -> Word -> Word -> Word
addMod NatJ#{} _ _ = tooLargeModulo
addMod (NatS# m#) (W# x#) (W# y#) = case addWordC# x# y# of
  (# s#, carry# #)
    | isTrue# carry# || isTrue# (s# `geWord#` m#) -> W# (s# `minusWord#` m#)
    | otherwise -> W# s#
-- | Modular subtraction. Assumes both operands are already below the modulus.
--
-- When @x < y@ the word difference wraps around, and adding @m@ brings it back
-- into range.
subMod :: Natural -> Word -> Word -> Word
subMod NatJ#{} _ _ = tooLargeModulo
subMod (NatS# m#) (W# x#) (W# y#)
  | isTrue# (x# `ltWord#` y#) = W# (d# `plusWord#` m#)
  | otherwise = W# d#
  where
    d# = x# `minusWord#` y#
-- | Modular negation of a residue below the modulus.
--
-- Zero maps to zero without inspecting the modulus at all; any other value
-- @x@ maps to @m - x@.
negateMod :: Natural -> Word -> Word
negateMod m x@(W# x#)
  | x == 0 = 0
  | NatS# m# <- m = W# (m# `minusWord#` x#)
  | otherwise = tooLargeModulo
-- | Modular multiplication via a double-word product.
--
-- The full 128-bit product is divided by @m@ with 'quotRemWord2#', which
-- needs the high word to be below @m@. That holds when both operands are
-- already reduced.
mulMod :: Natural -> Word -> Word -> Word
mulMod NatJ#{} _ _ = tooLargeModulo
mulMod (NatS# m#) (W# x#) (W# y#) =
  case timesWord2# x# y# of
    (# hi#, lo# #) -> case quotRemWord2# hi# lo# m# of
      (# _, r# #) -> W# r#
-- | Reduce an arbitrary 'Integer' modulo a word-sized modulus.
--
-- Throws 'DivideByZero' for modulus 0. Negative inputs are reduced by
-- magnitude and then negated with 'negateMod'.
fromIntegerMod :: Natural -> Integer -> Word
fromIntegerMod NatJ#{} _ = tooLargeModulo
fromIntegerMod (NatS# m#) !n
  | isTrue# (m# `eqWord#` 0##) = throw DivideByZero
  | otherwise = case n of
      S# i#
        | isTrue# (i# >=# 0#) -> W# (int2Word# i# `remWord#` m#)
        -- negateInt# of minBound wraps to minBound, whose int2Word# is still
        -- the correct magnitude 2^63.
        | otherwise -> negateMod (NatS# m#) (W# (int2Word# (negateInt# i#) `remWord#` m#))
      Jp# b# -> W# (b# `remBigNatWord` m#)
      Jn# b# -> negateMod (NatS# m#) (W# (b# `remBigNatWord` m#))
#ifdef MIN_VERSION_semirings
-- | Reduce an arbitrary 'Natural' modulo a word-sized modulus.
-- Throws 'DivideByZero' for modulus 0.
fromNaturalMod :: Natural -> Natural -> Word
fromNaturalMod NatJ#{} _ = tooLargeModulo
fromNaturalMod (NatS# m#) !n
  | isTrue# (m# `eqWord#` 0##) = throw DivideByZero
  | otherwise = case n of
      NatS# w# -> W# (w# `remWord#` m#)
      NatJ# b# -> W# (b# `remBigNatWord` m#)
#endif
-- | Error used on every code path in this module where the modulus is a
-- 'NatJ#', i.e. does not fit into a single machine word.
tooLargeModulo :: a
tooLargeModulo = error message
  where
    message = "modulo does not fit into a machine word"
-- | Ring arithmetic modulo @m@.
--
-- Each operation reads the modulus from the type with 'natVal' and delegates
-- to 'addMod', 'subMod', 'negateMod', 'mulMod' or 'fromIntegerMod'. 'abs' is
-- the identity. 'signum' ignores its argument and returns @1@, or @0@ when
-- @m <= 1@.
instance KnownNat m => Num (Mod m) where
  (Mod !a) + my@(Mod !b) = Mod (addMod (natVal my) a b)
  {-# INLINE (+) #-}
  (Mod !a) - my@(Mod !b) = Mod (subMod (natVal my) a b)
  {-# INLINE (-) #-}
  negate mx@(Mod !a) = Mod (negateMod (natVal mx) a)
  {-# INLINE negate #-}
  (Mod !a) * my@(Mod !b) = Mod (mulMod (natVal my) a b)
  {-# INLINE (*) #-}
  abs x = x
  {-# INLINE abs #-}
  signum _ = res
    where
      res
        | natVal res > 1 = Mod 1
        | otherwise = Mod 0
  {-# INLINE signum #-}
  fromInteger n = res
    where
      res = Mod (fromIntegerMod (natVal res) n)
  {-# INLINE fromInteger #-}
#ifdef MIN_VERSION_semirings
-- | 'Semiring' structure, which mirrors the 'Num' instance. 'one' is @Mod 0@
-- when @m <= 1@, because then every residue is @0@.
instance KnownNat m => Semiring (Mod m) where
  plus x y = x + y
  {-# INLINE plus #-}
  times x y = x * y
  {-# INLINE times #-}
  zero = Mod 0
  {-# INLINE zero #-}
  one = unit
    where
      unit
        | natVal unit > 1 = Mod 1
        | otherwise = Mod 0
  {-# INLINE one #-}
  fromNatural n = res
    where
      res = Mod (fromNaturalMod (natVal res) n)
  {-# INLINE fromNatural #-}
-- | Additive inverse, delegating to the 'Num' instance.
instance KnownNat m => Ring (Mod m) where
  negate x = P.negate x
  {-# INLINE negate #-}
-- | Division via modular inverses.
--
-- 'recip' throws 'DivideByZero' when 'invertMod' finds no inverse.
-- 'fromRational' reduces the numerator and divides by the reduced
-- denominator, so it throws when the denominator is not invertible modulo @m@.
instance KnownNat m => Fractional (Mod m) where
  fromRational q
    | den == 1 = num
    | otherwise = num / fromInteger den
    where
      num = fromInteger (numerator q)
      den = denominator q
  {-# INLINE fromRational #-}
  recip mx = maybe (throw DivideByZero) id (invertMod mx)
  {-# INLINE recip #-}
-- | Treats @Mod m@ as a field: 'divide' always succeeds via '(/)', which
-- throws if the divisor has no inverse. 'gcd' and 'lcm' are constantly 1 and
-- every pair is 'coprime'.
instance KnownNat m => GcdDomain (Mod m) where
  divide x y = Just $ x / y
  gcd _ _ = 1
  lcm _ _ = 1
  coprime _ _ = True
-- | Field-style Euclidean structure: every quotient is exact, so the
-- remainder is always 0 and the degree is constantly 0.
instance KnownNat m => Euclidean (Mod m) where
  degree _ = 0
  quotRem x y = (x / y, 0)
  quot x y = x / y
  rem _ _ = 0

-- | Marker instance; all operations come from the instances above.
instance KnownNat m => Field (Mod m)
#endif
-- | The multiplicative inverse of a residue, if one exists.
--
-- Returns 'Nothing' when the residue has no inverse modulo @m@ (it shares a
-- factor with @m@), and for the degenerate modulus @m = 0@.
--
-- For @m = 1@ the only residue is @0@, and @0 * 0 = 1 (mod 1)@, so its
-- inverse is @0@. This case is handled explicitly, because the odd-modulus
-- solver rejects a zero input outright.
invertMod :: KnownNat m => Mod m -> Maybe (Mod m)
invertMod mx@(Mod x) = case natVal mx of
  NatJ#{} -> tooLargeModulo
  NatS# 0## -> Nothing
  NatS# 1## -> Just (Mod 0)
  NatS# m# -> Mod <$> invertModWord x (W# m#)
-- | Inverse of @x@ modulo a word-sized modulus @m@.
--
-- Assumes @m /= 0@ ('invertMod' filters out 0 before calling). Writes
-- @m = 2^k * m'@ with @m'@ odd, where @k@ is the count of trailing zero bits.
-- If @k > 0@ and @x@ is even, no inverse exists. Otherwise an inverse @y@
-- modulo the odd part @m'@ is obtained from 'invertModWordOdd' and then lifted
-- to modulus @m@ by 'goDouble'.
invertModWord :: Word -> Word -> Maybe Word
invertModWord x m@(W# m#)
  -- Both x and m even: they share the factor 2, so no inverse.
  | even x, isTrue# (k# `gtWord#` 0##) = Nothing
  | otherwise = case invertModWordOdd x m' of
    Nothing -> Nothing
    -- goDouble only inspects the low k bits of (1 - x * acc), so word
    -- overflow/underflow in that expression is harmless.
    Just y -> Just $ goDouble y (1 - x * y)
  where
    k# = ctz# m#
    m' = m `unsafeShiftR` I# (word2Int# k#)
    xm' = x * m'
    -- Keeps r = 1 - x * acc (mod 2^wordsize) while adding m' * 2^tz to acc,
    -- which leaves acc unchanged modulo m' and clears the lowest set bit of r
    -- (x * m' is odd in this branch). It stops once r has at least k trailing
    -- zeros, i.e. x * acc = 1 modulo 2^k as well as modulo m'.
    goDouble :: Word -> Word -> Word
    goDouble acc r@(W# r#)
      | isTrue# (tz# `geWord#` k#)
      = acc
      | otherwise
      = goDouble (acc + m' `unsafeShiftL` tz) (r - xm' `unsafeShiftL` tz)
      where
        tz# = ctz# r#
        tz = I# (word2Int# tz#)
-- | Inverse of @x@ modulo an odd @m@, by a binary extended Euclidean
-- algorithm.
--
-- Tracks two pairs @(r, s)@ and @(r', s')@, starting from @(0, m)@ and
-- @(1, x)@, with the invariant @r * x = s@ and @r' * x = s'@ modulo @m@.
-- Even @s@ values are halved (halving @r@ modulo @m@ alongside), and the
-- smaller odd value is subtracted from the larger. When @s == s'@ that value
-- is the gcd; an inverse exists only if it is 1, and then @r@ is the inverse.
-- Zero has no inverse and is rejected immediately.
invertModWordOdd :: Word -> Word -> Maybe Word
invertModWordOdd 0 !_ = Nothing
invertModWordOdd !x !m = go00 0 m 1 x
  where
    -- (m + 1) / 2 for odd m: the correction needed to halve an odd r mod m.
    halfMp1 :: Word
    halfMp1 = half m + 1
    -- State names goAB: A/B is 1 when s/s' is known to be odd.
    -- Here both s and s' may be even.
    go00 :: Word -> Word -> Word -> Word -> Maybe Word
    go00 !r !s !r' !s'
      | even s = let (# hr, hs #) = doHalf r s in go00 hr hs r' s'
      | otherwise = go10 r s r' s'
    -- s is odd, s' may be even.
    go10 :: Word -> Word -> Word -> Word -> Maybe Word
    go10 !r !s !r' !s'
      | even s' = let (# hr', hs' #) = doHalf r' s' in go10 r s hr' hs'
      | otherwise = go11 r s r' s'
    -- s may be even, s' is odd.
    go01 :: Word -> Word -> Word -> Word -> Maybe Word
    go01 !r !s !r' !s'
      | even s = let (# hr, hs #) = doHalf r s in go01 hr hs r' s'
      | otherwise = go11 r s r' s'
    -- Both s and s' are odd: subtract the smaller from the larger, making the
    -- difference even, and adjust the matching r modulo m using the
    -- branch-free 'ge' to add m back when the subtraction would underflow.
    go11 :: Word -> Word -> Word -> Word -> Maybe Word
    go11 !r !s !r' !s' = case s `compare` s' of
      EQ -> if s == 1 then Just r else Nothing
      LT -> let newR' = r' - r + (r `ge` r') * m in
            let newS' = s' - s in
            let (# hr', hs' #) = doHalf newR' newS' in
            go10 r s hr' hs'
      GT -> let newR = r - r' + (r' `ge` r) * m in
            let newS = s - s' in
            let (# hr, hs #) = doHalf newR newS in
            go01 hr hs r' s'
    -- Halve s, and halve r modulo m: an odd r becomes (r - 1) / 2 + (m + 1) / 2.
    doHalf :: Word -> Word -> (# Word, Word #)
    doHalf r s = (# half r + (r .&. 1) * halfMp1, half s #)
    {-# INLINE doHalf #-}
-- | @1@ if the first argument is greater than or equal to the second, else
-- @0@. It is computed with the 'geWord#' primop rather than an explicit
-- branch.
ge :: Word -> Word -> Word
ge (W# a#) (W# b#) = W# (int2Word# (geWord# a# b#))
-- | Whether the lowest bit of the word is clear.
even :: Word -> Bool
even w = not (testBit w 0)
{-# INLINE even #-}
-- | Integer division by two, as a one-bit logical right shift.
half :: Word -> Word
half w = w `unsafeShiftR` 1
{-# INLINE half #-}
-- | Raise a residue to an integral power by binary (square-and-multiply)
-- exponentiation.
--
-- A negative exponent inverts the base with 'invertMod' first, and throws
-- 'DivideByZero' when the base has no inverse.
--
-- The accumulator starts from the multiplicative identity of @Mod m@. That is
-- @1@, except when @m <= 1@, where it is @0@ (matching 'signum' and 'one').
-- This keeps the result reduced below the modulus even for @x ^% 0@.
-- Previously @x ^% 0 :: Mod 1@ produced the unreduced @Mod 1@.
(^%) :: (KnownNat m, Integral a) => Mod m -> a -> Mod m
mx@(Mod (W# x#)) ^% a = case natVal mx of
  NatJ#{} -> tooLargeModulo
  NatS# m#
    | a < 0 -> case invertMod mx of
        Nothing -> throw DivideByZero
        Just (Mod (W# y#)) -> Mod $ W# (f y# (- a) unit#)
    | otherwise -> Mod $ W# (f x# a unit#)
    where
      -- Multiplicative identity of Mod m: 1, or 0 when m <= 1.
      unit# = if isTrue# (m# `gtWord#` 1##) then 1## else 0##
      -- f b e acc = acc * b^e (mod m), halving e at each step.
      f :: Integral a => Word# -> a -> Word# -> Word#
      f _ 0 acc# = acc#
      f b# e acc# = f bb# (e `P.quot` 2) (if odd e then ba# else acc#)
        where
          -- b^2 mod m and b * acc mod m via double-word products.
          !(# bb1#, bb2# #) = timesWord2# b# b#
          !(# _, bb# #) = quotRemWord2# bb1# bb2# m#
          !(# ba1#, ba2# #) = timesWord2# b# acc#
          !(# _, ba# #) = quotRemWord2# ba1# ba2# m#
{-# INLINABLE [1] (^%) #-}
{-# SPECIALISE [1] (^%) ::
  KnownNat m => Mod m -> Integer -> Mod m,
  KnownNat m => Mod m -> Natural -> Mod m,
  KnownNat m => Mod m -> Int -> Mod m,
  KnownNat m => Mod m -> Word -> Mod m #-}
-- Rewrite rules: route Prelude's '(^)' on 'Mod' to '(^%)', and expand small
-- literal exponents (2 and 3, at Integer/Int/Word) into plain multiplications.
-- NOTE(review): base's '(^)' errors on a negative exponent, whereas '(^%)'
-- inverts the base. When "powMod" fires, a negative exponent therefore no
-- longer errors, so the observable behaviour depends on whether the rule fires.
{-# RULES
"powMod" forall (x :: KnownNat m => Mod m) p. x ^ p = x ^% p
"powMod/2/Integer" forall x. x ^% (2 :: Integer) = let u = x in u*u
"powMod/3/Integer" forall x. x ^% (3 :: Integer) = let u = x in u*u*u
"powMod/2/Int" forall x. x ^% (2 :: Int) = let u = x in u*u
"powMod/3/Int" forall x. x ^% (3 :: Int) = let u = x in u*u*u
"powMod/2/Word" forall x. x ^% (2 :: Word) = let u = x in u*u
"powMod/3/Word" forall x. x ^% (3 :: Word) = let u = x in u*u*u #-}
-- Same fixity as Prelude's '(^)'.
infixr 8 ^%
#ifdef MIN_VERSION_vector
-- | Unboxed vectors of 'Mod' are stored as primitive vectors of the
-- underlying 'Word' residues.
newtype instance U.MVector s (Mod m) = MV_Mod (P.MVector s Word)
newtype instance U.Vector (Mod m) = V_Mod (P.Vector Word)
instance U.Unbox (Mod m)
-- | Mutable unboxed vectors of 'Mod': every method unwraps 'MV_Mod' and
-- delegates to the primitive 'Word' vector, converting elements with
-- 'Mod' / 'unMod'. All methods are INLINE so they fuse like the underlying
-- primitive instance; 'basicUnsafeMove' previously lacked its pragma.
instance M.MVector U.MVector (Mod m) where
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicOverlaps #-}
  {-# INLINE basicUnsafeNew #-}
  {-# INLINE basicInitialize #-}
  {-# INLINE basicUnsafeReplicate #-}
  {-# INLINE basicUnsafeRead #-}
  {-# INLINE basicUnsafeWrite #-}
  {-# INLINE basicClear #-}
  {-# INLINE basicSet #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE basicUnsafeMove #-}
  {-# INLINE basicUnsafeGrow #-}
  basicLength (MV_Mod v) = M.basicLength v
  basicUnsafeSlice i n (MV_Mod v) = MV_Mod $ M.basicUnsafeSlice i n v
  basicOverlaps (MV_Mod v1) (MV_Mod v2) = M.basicOverlaps v1 v2
  basicUnsafeNew n = MV_Mod <$> M.basicUnsafeNew n
  basicInitialize (MV_Mod v) = M.basicInitialize v
  basicUnsafeReplicate n x = MV_Mod <$> M.basicUnsafeReplicate n (unMod x)
  basicUnsafeRead (MV_Mod v) i = Mod <$> M.basicUnsafeRead v i
  basicUnsafeWrite (MV_Mod v) i x = M.basicUnsafeWrite v i (unMod x)
  basicClear (MV_Mod v) = M.basicClear v
  basicSet (MV_Mod v) x = M.basicSet v (unMod x)
  basicUnsafeCopy (MV_Mod v1) (MV_Mod v2) = M.basicUnsafeCopy v1 v2
  basicUnsafeMove (MV_Mod v1) (MV_Mod v2) = M.basicUnsafeMove v1 v2
  basicUnsafeGrow (MV_Mod v) n = MV_Mod <$> M.basicUnsafeGrow v n
-- | Immutable unboxed vectors of 'Mod': every method unwraps 'V_Mod' /
-- 'MV_Mod' and delegates to the primitive 'Word' vector. All methods are
-- INLINE; 'basicUnsafeCopy' previously lacked its pragma.
instance G.Vector U.Vector (Mod m) where
  {-# INLINE basicUnsafeFreeze #-}
  {-# INLINE basicUnsafeThaw #-}
  {-# INLINE basicLength #-}
  {-# INLINE basicUnsafeSlice #-}
  {-# INLINE basicUnsafeIndexM #-}
  {-# INLINE basicUnsafeCopy #-}
  {-# INLINE elemseq #-}
  basicUnsafeFreeze (MV_Mod v) = V_Mod <$> G.basicUnsafeFreeze v
  basicUnsafeThaw (V_Mod v) = MV_Mod <$> G.basicUnsafeThaw v
  basicLength (V_Mod v) = G.basicLength v
  basicUnsafeSlice i n (V_Mod v) = V_Mod $ G.basicUnsafeSlice i n v
  basicUnsafeIndexM (V_Mod v) i = Mod <$> G.basicUnsafeIndexM v i
  basicUnsafeCopy (MV_Mod mv) (V_Mod v) = G.basicUnsafeCopy mv v
  elemseq _ = seq
#endif