module PrimeField
( PrimeField
, toInt
) where
import Protolude

import Control.Monad.Random (Random(..), getRandom)
import GHC.Integer.GMP.Internals (powModInteger, recipModInteger)
import Test.Tasty.QuickCheck (Arbitrary(..), choose, oneof)
import Text.PrettyPrint.Leijen.Text (Pretty(..))

import GaloisField (GaloisField(..))
-- | An element of the prime field @GF(p)@, represented by an 'Integer'.
--
-- The arithmetic instances in this module keep the representative in
-- @[0, p)@; the 'PF' constructor itself performs no reduction.
--
-- NOTE(review): the derived 'Read' and (newtype-derived) 'Bits' instances
-- can produce representatives outside @[0, p)@ (e.g. 'complement' of an
-- 'Integer' is negative).
newtype PrimeField (p :: Nat) = PF Integer
  deriving
    ( Bits
    , Eq
    , Generic
    , NFData
    , Read
    , Show
    )
-- | @GF(p)@ as a Galois field: characteristic @p@, degree 1, and a Frobenius
-- map that is the identity.
instance KnownNat p => GaloisField (PrimeField p) where
  char = natVal
  {-# INLINE char #-}
  deg = const 1
  {-# INLINE deg #-}
  frob = identity
  {-# INLINE frob #-}
  -- Exponentiation by squaring modulo p.
  --
  -- Negative exponents go through an explicit modular inverse. Passing a
  -- negative exponent to 'powModInteger' is only supported when the inverse
  -- exists; otherwise it may abort the whole process with a divide-by-zero
  -- fault. Zero is therefore rejected with the same panic as 'recip'.
  pow w@(PF x) n
    | n >= 0 = PF (powModInteger x n p)
    | x == 0 = panic "no multiplicative inverse."
    | otherwise = PF (powModInteger (recipModInteger x p) (negate n) p)
    where
      p = natVal w
  {-# INLINE pow #-}
  quad = primeQuad
  {-# INLINE quad #-}
  rnd = getRandom
  {-# INLINE rnd #-}
  -- Every element of GF(2), and 0 in any field, is its own square root;
  -- all other cases go through Tonelli-Shanks ('primeSqrt').
  sr w@(PF x) =
    let p = natVal w
    in if p == 2 || x == 0 then Just w else PF <$> primeSqrt p x
  {-# INLINE sr #-}
-- | The underlying integer representative of a field element. For values
-- produced by this module's arithmetic it lies in @[0, p)@.
toInt :: PrimeField p -> Integer
toInt w = case w of
  PF n -> n
{-# INLINABLE toInt #-}
-- | Random field elements for property tests.
--
-- QuickCheck's 'Integer' generator is size-bounded. On its own it therefore
-- only yields representatives close to 0 or close to @p@, and for a large
-- @p@ products of such values never wrap around the modulus.
--
-- Half of the samples keep that small-integer generator, so edge cases such
-- as 0, 1 and @p - 1@ stay frequent. The other half are drawn from the
-- whole range @[0, p - 1]@.
instance KnownNat p => Arbitrary (PrimeField p) where
  arbitrary = oneof
    [ fromInteger <$> arbitrary
    , PF <$> choose (0, natVal (witness :: PrimeField p) - 1)
    ]
-- | Division via modular inverses.
--
-- 'recip' panics on zero. 'fromRational' therefore also panics when the
-- denominator reduces to zero modulo @p@.
instance KnownNat p => Fractional (PrimeField p) where
  recip w@(PF a)
    | a == 0 = panic "no multiplicative inverse."
    | otherwise = PF (recipModInteger a (natVal w))
  {-# INLINE recip #-}
  -- Same as the default @x / y = x * recip y@ applied to the reduced parts.
  fromRational (num :% den) = fromInteger num * recip (fromInteger den)
  {-# INLINABLE fromRational #-}
-- | Modular arithmetic on representatives in @[0, p)@.
--
-- Addition and subtraction use a single conditional correction instead of a
-- full reduction, which is valid because both operands are already reduced.
-- 'abs' and 'signum' have no field meaning and panic.
instance KnownNat p => Num (PrimeField p) where
  w@(PF a) + PF b =
    let modulus = natVal w
        total = a + b
    in PF (if total >= modulus then total - modulus else total)
  {-# INLINE (+) #-}
  w@(PF a) * PF b = PF ((a * b) `rem` natVal w)
  {-# INLINE (*) #-}
  w@(PF a) - PF b =
    let diff = a - b
    in PF (if diff < 0 then diff + natVal w else diff)
  {-# INLINE (-) #-}
  negate w@(PF a)
    | a == 0 = PF 0
    | otherwise = PF (natVal w - a)
  {-# INLINE negate #-}
  -- 'mod' with a positive modulus already yields a value in [0, p).
  fromInteger n = PF (n `mod` natVal (witness :: PrimeField p))
  {-# INLINABLE fromInteger #-}
  abs = panic "not implemented."
  signum = panic "not implemented."
-- | Pretty-prints a field element as its underlying integer.
instance KnownNat p => Pretty (PrimeField p) where
  pretty w = case w of
    PF n -> pretty n
-- | 'random' draws a representative with 'randomR' over @[0, p - 1]@.
-- Sampling from a caller-supplied range ('randomR') is not supported.
instance KnownNat p => Random (PrimeField p) where
  random g =
    let (n, g') = randomR (0, natVal (witness :: PrimeField p) - 1) g
    in (PF n, g')
  {-# INLINE random #-}
  randomR = panic "not implemented."
-- | Euler's criterion: 'True' unless @n^((p - 1) / 2) == 1 (mod p)@.
--
-- For an odd prime @p@ and @n@ not divisible by @p@, this is 'True' exactly
-- when @n@ is a quadratic non-residue. For odd @p@, @n == 0 (mod p)@ also
-- yields 'True', so zero has to be screened out by the caller.
isQNR :: Integer -> Integer -> Bool
isQNR p n = not (powModInteger n ((p - 1) `div` 2) p == 1)
{-# INLINE isQNR #-}
-- | Split @p - 1@ as @q * 2^s@ with @q@ odd, returning @(q, s)@.
--
-- NOTE(review): does not terminate for @p == 1@ (zero has no odd part);
-- callers are expected to pass a prime.
factor2 :: Integer -> (Integer, Int)
factor2 p = strip (p - 1) 0
  where
    strip :: Integer -> Int -> (Integer, Int)
    strip q s
      | odd q = (q, s)
      | otherwise = strip (q `div` 2) (s + 1)
{-# INLINE factor2 #-}
-- | Find a quadratic non-residue modulo the prime @p@.
--
-- When @p mod 8@ is 3 or 5, 2 is a non-residue. Otherwise the candidates
-- 3, 5, 7, 11, 13 and then every @6k +- 1@ from 17 upward are tested with
-- 'isQNR' in increasing order. This sequence contains every odd prime.
--
-- NOTE(review): never returns for @p == 2@, which has no non-residue; 'sr'
-- only reaches 'primeSqrt' (and hence this) with @p /= 2@.
getQNR :: Integer -> Integer
getQNR p = case p .&. 7 of
    3 -> 2
    5 -> 2
    _ -> maybe (panic "no quadratic nonresidue.") identity
               (find (isQNR p) candidates)
  where
    candidates = 3 : 5 : 7 : 11 : 13 : [k + d | k <- [18, 24 ..], d <- [-1, 1]]
{-# INLINE getQNR #-}
-- | Square root modulo a prime via the Tonelli-Shanks algorithm.
--
-- @primeSqrt p n@ is 'Nothing' when Euler's criterion ('isQNR') rejects @n@.
-- Otherwise it is @Just r@ with @r * r == n (mod p)@. Of the two roots @r@
-- and @p - r@, the smaller one is returned.
--
-- NOTE(review): assumes @p@ is an odd prime and @n@ a non-zero residue below
-- @p@; 'sr' only calls it with @p /= 2@ and @n /= 0@.
primeSqrt :: Integer -> Integer -> Maybe Integer
primeSqrt p n
  | isQNR p n = Nothing
  | otherwise = fmap smaller (step s c0 t0 r0)
  where
    -- p - 1 = q * 2^s with q odd.
    (q, s) = factor2 p
    -- z^q for a non-residue z. It is lazily bound, so 'getQNR' only runs
    -- if at least one correction step is needed.
    c0 = powModInteger (getQNR p) q p
    -- nHalf = n^((q - 1) / 2), hence r0 = n^((q + 1) / 2) and t0 = n^q.
    nHalf = powModInteger n (quot q 2) p
    r0 = rem (n * nHalf) p
    t0 = rem (nHalf * r0) p
    smaller root = min root (p - root)
    square u = rem (u * u) p
    -- Number of repeated squarings needed to turn u into 1.
    squaringsToOne :: Integer -> Int -> Int
    squaringsToOne u k
      | u == 1 = k
      | otherwise = squaringsToOne (square u) (k + 1)
    -- Invariant: r^2 == n * t (mod p); finished once t reaches 1 (or 0).
    step :: Int -> Integer -> Integer -> Integer -> Maybe Integer
    step m c t r
      | t == 0 = Just 0
      | t == 1 = Just r
      | otherwise =
          let i = squaringsToOne t 0
              b = powModInteger c (bit (m - i - 1)) p
              b2 = square b
          in step i b2 (rem (t * b2) p) (rem (r * b) p)
{-# INLINE primeSqrt #-}
-- | Solve @a*x^2 + b*x + c = 0@ over @GF(p)@, returning one root if any.
--
-- Returns 'Nothing' when @a == 0@ or when no root exists. In characteristic
-- 2, the only candidates 0 and 1 are checked directly (0 is a root iff
-- @c == 0@; with @a == 1@ and @c == 1@, 1 is a root iff @b == 0@).
-- Otherwise the quadratic formula is applied, using 'sr' to take the square
-- root of the discriminant.
primeQuad :: KnownNat p
  => PrimeField p -> PrimeField p -> PrimeField p -> Maybe (PrimeField p)
primeQuad a b c
  | a == 0 = Nothing
  | characteristic == 2 = case (c == 0, b == 0) of
      (True, _) -> Just 0
      (_, True) -> Just 1
      _ -> Nothing
  | otherwise = fmap (\root -> (root - b) / (2 * a)) (sr discriminant)
  where
    characteristic = char a :: Integer
    discriminant = b * b - 4 * a * c
{-# INLINE primeQuad #-}