-- | Small prime fields (up to @p < 2^31@), without type safety.
--
-- This module is considered internal.
--

{-# LANGUAGE BangPatterns #-}
module Math.FiniteField.PrimeField.Small.Raw where

--------------------------------------------------------------------------------

import Data.Bits
import Data.Int
import Data.Word
import GHC.TypeNats (Nat)

-- import Math.FiniteField.Primes

--------------------------------------------------------------------------------

-- | The modulus @p@. Per the module header this is a prime below @2^31@;
-- nothing here checks that.
type P = Word64
-- | A field element. The operations in this module presumably expect it to
-- be kept reduced into @[0, p)@ (e.g. 'add' and 'sub' only return reduced
-- results for reduced inputs) -- callers are responsible for that.
type F = Word64 

-- | Additive inverse @-x (mod p)@ of a reduced element @x@.
-- Zero is its own negation; any other @x@ maps to @p - x@.
neg :: P -> F -> F
neg !p !x
  | x == 0    = 0
  | otherwise = p - x

-- | Modular addition of two reduced elements. The raw sum is below @2p@
-- (and, with @p < 2^31@, far below @2^64@), so a single conditional
-- subtraction brings it back into range.
add :: P -> F -> F -> F
add !p !x !y
  | s >= p    = s - p
  | otherwise = s
  where
    s = x + y

-- | Modular subtraction @x - y (mod p)@ of two reduced elements.
-- When @x < y@ the modulus is added first so the unsigned 'Word64'
-- arithmetic never wraps around.
sub :: P -> F -> F -> F
sub !p !x !y
  | x < y     = p + x - y
  | otherwise = x - y

-- | Modular multiplication. For reduced inputs with @p < 2^31@ the product
-- stays below @2^62@, so it cannot overflow 'Word64' before the reduction.
mul :: P -> F -> F -> F
mul !p !x !y = (x * y) `mod` p

--------------------------------------------------------------------------------
-- * Nontrivial operations

-- | Exponentiation @z^e@ in @F_p@ by right-to-left square-and-multiply.
--
-- For nonzero @z@ we have @z^(p-1) == 1@ (Fermat), so the exponent is first
-- reduced modulo @p - 1@. Haskell's 'mod' returns a non-negative result for
-- a positive divisor, so this one step also handles negative exponents:
-- no inversion is needed, and -- unlike negating the exponent, which
-- overflows for @e == minBound@ and previously caused infinite recursion --
-- it is total on every 'Int64'.
--
-- By convention the zero base yields 0 for every exponent (including
-- @e <= 0@).
pow :: P -> F -> Int64 -> F
pow !p !z !e
  | z == 0    = 0
  | e == 0    = 1
  | otherwise = go 1 z (mod e pm1i)
  where
    -- p - 1 as an Int64; fits because p < 2^31.
    pm1i :: Int64
    pm1i = fromIntegral (p - 1)

    -- go acc y k == acc * y^k (mod p), for k >= 0.
    go :: F -> F -> Int64 -> F
    go !acc !y !k
      | k == 0       = acc
      | k .&. 1 == 0 = go acc (mul p y y) (shiftR k 1)
      | otherwise    = go (mul p acc y) (mul p y y) (shiftR k 1)

-- | Like 'pow', but taking an arbitrary-precision 'Integer' exponent.
--
-- Negative exponents first invert the base. Non-negative exponents of at
-- least @p - 1@ are reduced modulo @p - 1@ (using @z^(p-1) == 1@) so that
-- they fit the 'Int64' exponent of 'pow'. The zero base always yields 0.
pow' :: P -> F -> Integer -> F
pow' !p !z !e
  | z == 0    = 0
  | e == 0    = 1
  | e < 0     = pow' p (inv p z) (negate e)
  | otherwise = pow p z (fromIntegral e')
  where
    order :: Integer
    order = toInteger (p - 1)
    -- Only reduce when needed; in the positive case e' < p - 1 < 2^31.
    e' :: Integer
    e' = if e >= order then e `mod` order else e

-- | Multiplicative inverse of @a@ modulo @p@, computed with the binary
-- extended Euclidean algorithm ('euclid64'). Zero has no inverse; by
-- convention @inv p 0 == 0@.
inv :: P -> F -> F
inv !p !a = if a == 0 then 0 else euclid64 p 1 0 a p

-- | Division @a / b@ in @F_p@, computed directly by 'euclid64' (seeded with
-- @a@ instead of 1, so no separate inversion is performed).
-- Division by zero returns 0 by convention.
div :: P -> F -> F -> F
div !p !a !b
  | b /= 0    = euclid64 p a 0 b p
  | otherwise = 0

-- | Division @a / b@ computed as @a * inv b@: an alternative to 'div' that
-- performs an explicit inversion. Also yields 0 when @b == 0@, since
-- @inv p 0 == 0@.
div2 :: P -> F -> F -> F
div2 !p !a !b = mul p a bInv
  where
    bInv = inv p b

--------------------------------------------------------------------------------
-- * Euclidean algorithm

-- | Extended binary Euclidean algorithm in @F_p@.
--
-- @euclid64 p x1 x2 u v@ iterates while preserving, for one fixed @r@,
--
-- > x1 == r * u   (mod p)
-- > x2 == r * v   (mod p)
--
-- and returns the residue paired with whichever of @u@, @v@ first reaches 1,
-- i.e. @r@ itself. In this module it is called as
--
-- * @euclid64 p 1 0 a p@, giving @r = 1/a@ (see 'inv');
-- * @euclid64 p a 0 b p@, giving @r = a/b@ (see 'div').
--
-- Halving uses @(p+1)/2@ as the inverse of 2, so @p@ is expected to be odd.
-- NOTE(review): the result is only guaranteed to lie in @[0, p)@ when the
-- inputs do (e.g. an unreduced @x1@ is returned as-is if @u == 1@).
--
-- If @gcd u v /= 1@ (for instance @u == 0@, or @u@ a nonzero multiple of an
-- odd @p@), one of @u@, @v@ becomes 0 without either reaching 1. The
-- function then returns 0; previously it looped forever, because halving
-- zero leaves it zero.
euclid64 :: Word64 -> Word64 -> Word64 -> Word64 -> Word64 -> Word64
euclid64 !p !x10 !x20 !u0 !v0 = go x10 x20 u0 v0
  where
    -- (p + 1) / 2 is the inverse of 2 modulo an odd p.
    halfp1 :: Word64
    halfp1 = shiftR (p + 1) 1

    -- Halve a residue modulo p. For odd x this is (x + p) / 2, which equals
    -- shiftR x 1 + (p + 1) / 2 when p is odd.
    halve :: Word64 -> Word64
    halve !x = if even x then shiftR x 1 else shiftR x 1 + halfp1

    modp :: Word64 -> Word64
    modp !n = mod n p

    -- a - b (mod p), adding p first when a < b so Word64 never wraps.
    subModp :: Word64 -> Word64 -> Word64
    subModp !a !b = if a >= b then modp (a - b) else modp (a + p - b)

    go :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    go !x1 !x2 !u !v
      | u == 1           = x1
      | v == 1           = x2
      -- gcd u v /= 1: no quotient exists; stop instead of halving 0 forever.
      | u == 0 || v == 0 = 0
      | otherwise        = stepU x1 x2 u v

    -- Strip factors of two from u, halving x1 alongside (keeps x1 = r*u).
    stepU :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    stepU !x1 !x2 !u !v
      | even u    = stepU (halve x1) x2 (shiftR u 1) v
      | otherwise = stepV x1 x2 u v

    -- Strip factors of two from v, halving x2 alongside (keeps x2 = r*v).
    stepV :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    stepV !x1 !x2 !u !v
      | even v    = stepV x1 (halve x2) u (shiftR v 1)
      | otherwise = final x1 x2 u v

    -- Both u and v are odd here: subtract the smaller from the larger and
    -- do the same to the paired residues modulo p, then loop.
    final :: Word64 -> Word64 -> Word64 -> Word64 -> Word64
    final !x1 !x2 !u !v
      | u >= v    = go (subModp x1 x2) x2 (u - v) v
      | otherwise = go x1 (subModp x2 x1) u (v - u)

--------------------------------------------------------------------------------