module Arithmetic.Montgomery
where
import OpenTheory.Primitive.Natural
import qualified OpenTheory.Natural.Bits as Bits
import OpenTheory.Natural.Divides
import Arithmetic.Utility
-- | Precomputed constants for Montgomery arithmetic with a fixed modulus.
-- Values are normally built by 'customParameters', which fills the fields
-- as described below.
data Parameters = Parameters
    { nParameters :: Natural
      -- ^ The modulus @n@.
    , wParameters :: Natural
      -- ^ The bit width @w@ used for shifting and truncation.
    , sParameters :: Natural
      -- ^ First cofactor returned by @egcd (shiftLeft 1 w) n@.
    , kParameters :: Natural
      -- ^ Second cofactor returned by @egcd (shiftLeft 1 w) n@; used by 'reduce'.
    , rParameters :: Natural
      -- ^ @shiftLeft 1 w `mod` n@.
    , r2Parameters :: Natural
      -- ^ @(r * r) `mod` n@; used by 'fromNatural'.
    , zParameters :: Natural
      -- ^ @shiftLeft 1 w + n - r@; used by 'negate'.
    }
  deriving (Show)
-- | A number paired with the 'Parameters' it is to be interpreted under.
data Montgomery = Montgomery
    { pMontgomery :: Parameters
      -- ^ The parameters this value belongs to.
    , nMontgomery :: Natural
      -- ^ The stored representation of the value.
    }
  deriving (Show)
-- | Round @n@ up to the nearest multiple of the block size @b@.
--
-- Zero is returned unchanged. For positive @n@ the computation divides by
-- @b@, so @b@ must be non-zero in that case.
align :: Natural -> Natural -> Natural
align b n
  | n == 0 = 0
  | otherwise = (fullBlocks + 1) * b
  where
    -- Number of whole blocks strictly below @n@.
    fullBlocks = (n - 1) `div` b
-- | Build the 'Parameters' for modulus @n@ and bit width @w@.
--
-- All derived constants are computed from @shiftLeft 1 w@ and @n@.
-- NOTE(review): the cofactors come straight from 'egcd'; presumably the
-- modulus must be coprime to @shiftLeft 1 w@ (i.e. odd) for them to be
-- useful in 'reduce' - confirm against the 'egcd' contract.
customParameters :: Natural -> Natural -> Parameters
customParameters n w =
    let base = shiftLeft 1 w
        cofactors = snd (egcd base n)
        residue = base `mod` n
        residueSquared = (residue * residue) `mod` n
        negationOffset = base + n - residue
    in Parameters
         { nParameters = n
         , wParameters = w
         , sParameters = fst cofactors
         , kParameters = snd cofactors
         , rParameters = residue
         , r2Parameters = residueSquared
         , zParameters = negationOffset
         }
-- | Build 'Parameters' for modulus @n@, choosing the bit width by rounding
-- @Bits.width n@ up to a multiple of @b@ with 'align'.
alignedParameters :: Natural -> Natural -> Parameters
alignedParameters b n =
    customParameters n paddedWidth
  where
    modulusWidth = Bits.width n
    paddedWidth = align b modulusWidth
-- | Build 'Parameters' for modulus @n@ with the bit width aligned to a
-- multiple of 64.
standardParameters :: Natural -> Parameters
standardParameters n =
    alignedParameters blockSize n
  where
    blockSize = 64
-- | Wrap a natural number as a 'Montgomery' value, first shrinking it until
-- nothing remains above bit position @w@ (that is, @shiftRight a w == 0@).
--
-- Each round removes the part of the number at or above position @w@ and
-- adds back that overflow multiplied by @rParameters p@.
normalize :: Parameters -> Natural -> Montgomery
normalize p =
    go
  where
    width = wParameters p
    residue = rParameters p

    go a
      | overflow == 0 = Montgomery {pMontgomery = p, nMontgomery = a}
      | otherwise = go (lowPart + overflow * residue)
      where
        -- Everything at or above bit position @width@.
        overflow = shiftRight a width
        -- What is left once the overflow has been stripped off.
        lowPart = a - shiftLeft overflow width
-- | Wrap a natural number as a 'Montgomery' value, subtracting the modulus
-- once when bit @w@ of the number is set (as reported by @Bits.bit@).
normalize1 :: Parameters -> Natural -> Montgomery
normalize1 p a =
    Montgomery p adjusted
  where
    adjusted
      | Bits.bit a (wParameters p) = a - nParameters p
      | otherwise = a
-- | Montgomery reduction step: add a multiple of the modulus to @a@ and
-- shift the sum right by @w@ bits.
--
-- The multiplier is @Bits.bound (a * k) w@.
-- NOTE(review): presumably @Bits.bound x w@ keeps the low @w@ bits of @x@,
-- which would make the sum divisible by @2^w@ - confirm against
-- "OpenTheory.Natural.Bits".
reduce :: Parameters -> Natural -> Natural
reduce p a =
    shiftRight (a + multiplier * modulus) width
  where
    modulus = nParameters p
    width = wParameters p
    multiplier = Bits.bound (a * kParameters p) width
-- | Convert a 'Montgomery' value back to a plain natural number.
--
-- The stored representation is passed through 'reduce'; the result is
-- returned when it is below the modulus, and @0@ is returned otherwise.
toNatural :: Montgomery -> Natural
toNatural a
  | reduced < modulus = reduced
  | otherwise = 0
  where
    params = pMontgomery a
    modulus = nParameters params
    reduced = reduce params (nMontgomery a)
-- | Convert a natural number into a 'Montgomery' value under parameters @p@.
--
-- The number is first passed through 'normalize' and then multiplied (with
-- 'multiply') by the value whose representation is @r2Parameters p@.
fromNatural :: Parameters -> Natural -> Montgomery
fromNatural p a =
    multiply conversionFactor (normalize p a)
  where
    conversionFactor = Montgomery p (r2Parameters p)
-- | The 'Montgomery' value under @p@ whose stored representation is @0@.
zero :: Parameters -> Montgomery
zero p = Montgomery p 0
-- | The 'Montgomery' value under @p@ whose stored representation is
-- @rParameters p@.
one :: Parameters -> Montgomery
one p = Montgomery p (rParameters p)
-- | The 'Montgomery' value obtained by doubling 'one'.
two :: Parameters -> Montgomery
two = double . one
-- | Add two 'Montgomery' values.
--
-- The stored representations are summed and the sum is passed through
-- 'normalize'. The parameters of the first argument are used; those of the
-- second argument are not consulted.
add :: Montgomery -> Montgomery -> Montgomery
add a b =
    normalize params total
  where
    params = pMontgomery a
    total = nMontgomery a + nMontgomery b
-- | Add a 'Montgomery' value to itself.
double :: Montgomery -> Montgomery
double x = x `add` x
-- | Negate a 'Montgomery' value.
--
-- The stored representation is subtracted from @zParameters@ and the
-- difference is passed through 'normalize1'.
negate :: Montgomery -> Montgomery
negate a =
    let params = pMontgomery a
        complement = zParameters params - nMontgomery a
    in normalize1 params complement
-- | Subtract the second 'Montgomery' value from the first by adding its
-- negation.
subtract :: Montgomery -> Montgomery -> Montgomery
subtract a = add a . Arithmetic.Montgomery.negate
-- | Multiply two 'Montgomery' values.
--
-- The stored representations are multiplied, the product is passed through
-- 'reduce', and the result is wrapped with 'normalize1'. The parameters of
-- the first argument are used.
multiply :: Montgomery -> Montgomery -> Montgomery
multiply a b =
    normalize1 params reduced
  where
    params = pMontgomery a
    rawProduct = nMontgomery a * nMontgomery b
    reduced = reduce params rawProduct
-- | Multiply a 'Montgomery' value by itself.
square :: Montgomery -> Montgomery
square x = x `multiply` x
-- | Raise a 'Montgomery' value to a natural-number power.
--
-- Delegates to 'multiplyExponential', supplying 'multiply' as the operation
-- and 'one' (under the value's own parameters) as the starting unit.
-- NOTE(review): the exact exponentiation strategy lives in
-- "Arithmetic.Utility" and is not visible here.
exp :: Montgomery -> Natural -> Montgomery
exp a k =
    multiplyExponential multiply unit a k
  where
    unit = one (pMontgomery a)
-- | Apply 'square' to a 'Montgomery' value via @functionPower square k@.
-- NOTE(review): presumably this squares @k@ times, giving @a^(2^k)@ -
-- confirm against 'functionPower' in "Arithmetic.Utility".
exp2 :: Montgomery -> Natural -> Montgomery
exp2 a k =
    repeatedSquare a
  where
    repeatedSquare = functionPower square k
-- | Modular exponentiation through Montgomery arithmetic.
--
-- Builds 'standardParameters' for the modulus @n@, converts @a@ with
-- 'fromNatural', raises it to the power @k@ with 'Arithmetic.Montgomery.exp',
-- and converts the result back with 'toNatural'.
modexp :: Natural -> Natural -> Natural -> Natural
modexp n a k =
    let params = standardParameters n
        base = fromNatural params a
        power = Arithmetic.Montgomery.exp base k
    in toNatural power
-- | Variant of 'modexp' that uses 'exp2' for the exponentiation step.
--
-- Builds 'standardParameters' for the modulus @n@, converts @a@ with
-- 'fromNatural', applies @exp2 _ k@, and converts back with 'toNatural'.
modexp2 :: Natural -> Natural -> Natural -> Natural
modexp2 n a k =
    let params = standardParameters n
        base = fromNatural params a
        power = exp2 base k
    in toNatural power