{- |
module: Arithmetic.Montgomery
description: Modular arithmetic using Montgomery multiplication
license: MIT

maintainer: Joe Leslie-Hurd
stability: provisional
portability: portable
-}
module Arithmetic.Montgomery
where

import OpenTheory.Primitive.Natural
import qualified OpenTheory.Natural.Bits as Bits
import OpenTheory.Natural.Divides

import qualified Arithmetic.Modular as Modular

-- | Constants for Montgomery arithmetic modulo @n@ with radix @2 ^ w@,
-- as built by 'customParameters'.
data Parameters = Parameters
  { nParameters :: Natural
    -- ^ The modulus @n@.
  , wParameters :: Natural
    -- ^ The word width @w@; the Montgomery radix is @2 ^ w@.
  , sParameters :: Natural
    -- ^ First coefficient returned by @egcd (2 ^ w) n@.  Per the contract
    -- of 'reduce', reducing multiplies by this value modulo @n@.
  , kParameters :: Natural
    -- ^ Second coefficient returned by @egcd (2 ^ w) n@; 'reduce' uses it
    -- to pick the multiple of @n@ it adds before shifting.
  , rParameters :: Natural
    -- ^ @(2 ^ w) `mod` n@, which 'one' uses as its stored value.
  , r2Parameters :: Natural
    -- ^ @(r * r) `mod` n@, used by 'fromNatural' to enter Montgomery form.
  , zParameters :: Natural
    -- ^ @2 ^ w + n - r@.  Since @r = (2 ^ w) `mod` n@ this is a multiple
    -- of @n@, so subtracting from it negates modulo @n@.
  }
  deriving Show

-- | A residue held in Montgomery (scaled) form together with the parameters
-- it was built from.  Binary operations in this module take the parameters
-- of their first argument.
data Montgomery = Montgomery
  { pMontgomery :: Parameters
    -- ^ The parameters for this value.
  , nMontgomery :: Natural
    -- ^ The stored representative.  The operations here are intended to keep
    -- it below @2 ^ w@, but it is not necessarily below @n@.
  }
  deriving Show

-- | @align b n@ rounds @n@ up to the next multiple of @b@.  @align b 0 = 0@;
-- for @n > 0@ the step @b@ must be positive.
align :: Natural -> Natural -> Natural
align b n
  | n == 0 = 0
  | otherwise = b * (1 + (n - 1) `div` b)

-- | Build the Montgomery constants for modulus @n@ and word width @w@.
--
-- NOTE(review): the derived constants appear to assume @n@ is odd (so that
-- @egcd (2 ^ w) n@ has gcd 1) and @n < 2 ^ w@ (so that @r < 2 ^ w@).
-- Neither condition is checked here.
customParameters :: Natural -> Natural -> Parameters
customParameters n w =
    Parameters
      { nParameters = n
      , wParameters = w
      , sParameters = s
      , kParameters = k
      , rParameters = r
      , r2Parameters = r2
      , zParameters = z
      }
  where
    radix = shiftLeft 1 w
    (_, (s, k)) = egcd radix n
    r = radix `mod` n
    r2 = (r * r) `mod` n
    z = radix + n - r

-- | Parameters whose word width is the bit width of @n@ rounded up to a
-- multiple of @b@.
alignedParameters :: Natural -> Natural -> Parameters
alignedParameters b n = customParameters n wordWidth
  where
    wordWidth = align b (Bits.width n)

-- | Parameters with the word width aligned to 64 bits.
standardParameters :: Natural -> Parameters
standardParameters = alignedParameters 64

-- | Fold a natural number below @2 ^ w@ without changing it modulo @n@.
--
-- * @normalize p a `mod` n = a `mod` n@
--
-- * @normalize p a < 2 ^ w@
normalize :: Parameters -> Natural -> Montgomery
normalize p = go
  where
    bits = wParameters p
    rad = rParameters p

    -- Split a into high and low halves at bit w.  Because 2 ^ w and r agree
    -- modulo n, replace high * 2 ^ w by high * r and repeat until the high
    -- half vanishes.
    go a
      | high == 0 = Montgomery p a
      | otherwise = go (low + high * rad)
      where
        high = shiftRight a bits
        low = a - shiftLeft high bits

-- | Bring a value that is at most @n@ too large back below @2 ^ w@ with at
-- most one subtraction.
--
-- * @normalize1 p a `mod` n = a `mod` n@
--
-- * @a < 2 ^ w + n ==> normalize1 p a < 2 ^ w@
normalize1 :: Parameters -> Natural -> Montgomery
normalize1 p a = Montgomery {pMontgomery = p, nMontgomery = folded}
  where
    -- Under the precondition, bit w is set exactly when a >= 2 ^ w.
    overflow = Bits.bit a (wParameters p)
    folded
      | overflow = a - nParameters p
      | otherwise = a

-- | Montgomery reduction: add a multiple of @n@, then shift right by @w@.
--
-- * @reduce p a `mod` n = (a * s) `mod` n@
--
-- * @a <= 2 ^ w * x ==> reduce p a < x + n@
reduce :: Parameters -> Natural -> Natural
reduce p a = shiftRight (a + m * nParameters p) w
  where
    w = wParameters p
    -- m < 2 ^ w, so the added term is below 2 ^ w * n.
    m = Bits.bound (a * kParameters p) w

-- | Leave Montgomery form, producing the residue in @[0, n)@.
toNatural :: Montgomery -> Natural
toNatural (Montgomery p x)
  | y < nParameters p = y
  -- With x < 2 ^ w the bound on 'reduce' gives y < 1 + n.  So y == n is the
  -- only out-of-range result, and it stands for 0.
  | otherwise = 0
  where
    y = reduce p x

-- | Enter Montgomery form: first 'normalize' the input, then multiply by the
-- stored @r2Parameters@ value.
fromNatural :: Parameters -> Natural -> Montgomery
fromNatural p a = multiply rSquared (normalize p a)
  where
    rSquared = Montgomery p (r2Parameters p)

-- | The Montgomery value storing 0.
zero :: Parameters -> Montgomery
zero p = Montgomery p 0

-- | The Montgomery value storing @rParameters@, the representation of 1.
one :: Parameters -> Montgomery
one p = Montgomery p (rParameters p)

-- | The representation of 2, obtained by doubling 'one'.
two :: Parameters -> Montgomery
two = double . one

-- | Modular addition.  The raw sum is re-folded below @2 ^ w@ by 'normalize'.
add :: Montgomery -> Montgomery -> Montgomery
add a b = normalize p (x + y)
  where
    p = pMontgomery a
    x = nMontgomery a
    y = nMontgomery b

-- | Add a value to itself.
double :: Montgomery -> Montgomery
double a = add a a

-- | Modular negation: subtract from @zParameters@, a multiple of @n@ that is
-- larger than any stored value.
negate :: Montgomery -> Montgomery
negate a = normalize1 p (zParameters p - nMontgomery a)
  where
    p = pMontgomery a

-- | Modular subtraction, as addition of the negation.
subtract :: Montgomery -> Montgomery -> Montgomery
subtract a = add a . Arithmetic.Montgomery.negate

-- | Montgomery multiplication.  Both inputs are below @2 ^ w@, so 'reduce'
-- yields a value below @2 ^ w + n@, which 'normalize1' accepts.
multiply :: Montgomery -> Montgomery -> Montgomery
multiply a b = normalize1 p prod
  where
    p = pMontgomery a
    prod = reduce p (nMontgomery a * nMontgomery b)

-- | Multiply a value by itself.
square :: Montgomery -> Montgomery
square a = multiply a a

-- | Raise to a natural power using 'Modular.multiplyExponential' with
-- 'multiply' and 'one'.
exp :: Montgomery -> Natural -> Montgomery
exp a k = Modular.multiplyExponential multiply unit a k
  where
    unit = one (pMontgomery a)

-- | Apply 'square' via @Modular.functionPower square k@.  This presumably
-- squares @k@ times, giving @a ^ (2 ^ k)@; confirm against
-- "Arithmetic.Modular".
exp2 :: Montgomery -> Natural -> Montgomery
exp2 a k = Modular.functionPower square k a

-- | Convert @a@ into Montgomery form with 'standardParameters' for @n@, raise
-- it with 'exp', and convert back.
--
-- NOTE(review): this presumably computes @(a ^ k) `mod` n@ and needs @n@ odd.
modexp :: Natural -> Natural -> Natural -> Natural
modexp n a k = toNatural (Arithmetic.Montgomery.exp base k)
  where
    base = fromNatural (standardParameters n) a

-- | Like 'modexp', but uses 'exp2' for the exponentiation step.
modexp2 :: Natural -> Natural -> Natural -> Natural
modexp2 n a k = toNatural (exp2 base k)
  where
    base = fromNatural (standardParameters n) a