```-- |
-- Module:      Math.NumberTheory.Moduli.Multiplicative
-- Copyright:   (c) 2017 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Multiplicative groups of integers modulo m.
--

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE PatternSynonyms #-}

module Math.NumberTheory.Moduli.Multiplicative
( -- * Multiplicative group
MultMod
, multElement
, isMultElement
, invertGroup
-- * Primitive roots
, PrimitiveRoot
, unPrimitiveRoot
, isPrimitiveRoot
, discreteLogarithm
) where

import Data.Constraint
import Data.Mod
import Data.Semigroup
import GHC.TypeNats (KnownNat, natVal)
import Numeric.Natural

import Math.NumberTheory.Moduli.Internal
import Math.NumberTheory.Moduli.Singleton
import Math.NumberTheory.Primes

-- | This type represents elements of the multiplicative group mod m, i.e.
-- those elements which are coprime to m. Use @toMultElement@ to construct.
newtype MultMod m = MultMod {
multElement :: Mod m -- ^ Unwrap a residue.
} deriving (Eq, Ord, Show)

instance KnownNat m => Semigroup (MultMod m) where
MultMod a <> MultMod b = MultMod (a * b)
stimes k a@(MultMod a')
| k >= 0 = MultMod (a' ^% k)
| otherwise = invertGroup \$ stimes (-k) a
-- ^ This Semigroup is in fact a group, so @stimes@ can be called with a negative first argument.

instance KnownNat m => Monoid (MultMod m) where
mempty = MultMod 1
mappend = (<>)

instance KnownNat m => Bounded (MultMod m) where
minBound = MultMod 1
maxBound = MultMod (-1)

-- | Attempt to construct a multiplicative group element.
isMultElement :: KnownNat m => Mod m -> Maybe (MultMod m)
isMultElement a = if unMod a `gcd` natVal a == 1
then Just \$ MultMod a
else Nothing

-- | For elements of the multiplicative group, we can safely perform the inverse
-- without needing to worry about failure.
invertGroup :: KnownNat m => MultMod m -> MultMod m
invertGroup (MultMod a) = case invertMod a of
Just b -> MultMod b
Nothing -> error "Math.NumberTheory.Moduli.invertGroup: failed to invert element"

-- | 'PrimitiveRoot' m is a type which is only inhabited
-- by <https://en.wikipedia.org/wiki/Primitive_root_modulo_n primitive roots> of m.
newtype PrimitiveRoot m = PrimitiveRoot
{ unPrimitiveRoot :: MultMod m -- ^ Extract primitive root value.
}
deriving (Eq, Show)

-- | Check whether a given modular residue is
-- a <https://en.wikipedia.org/wiki/Primitive_root_modulo_n primitive root>.
--
-- >>> :set -XDataKinds
-- >>> import Data.Maybe
-- >>> isPrimitiveRoot (fromJust cyclicGroup) (1 :: Mod 13)
-- Nothing
-- >>> isPrimitiveRoot (fromJust cyclicGroup) (2 :: Mod 13)
-- Just (PrimitiveRoot {unPrimitiveRoot = MultMod {multElement = (2 `modulo` 13)}})
isPrimitiveRoot
:: (Integral a, UniqueFactorisation a)
=> CyclicGroup a m
-> Mod m
-> Maybe (PrimitiveRoot m)
isPrimitiveRoot cg r = case proofFromCyclicGroup cg of
Sub Dict -> do
r' <- isMultElement r
guard \$ isPrimitiveRoot' cg (fromIntegral (unMod r))
return \$ PrimitiveRoot r'

-- | Computes the discrete logarithm. Currently uses a combination of the baby-step
-- giant-step method and Pollard's rho algorithm, with Bach reduction.
--
-- >>> :set -XDataKinds
-- >>> import Data.Maybe
-- >>> let cg = fromJust cyclicGroup :: CyclicGroup Integer 13
-- >>> let rt = fromJust (isPrimitiveRoot cg 2)
-- >>> let x  = fromJust (isMultElement 11)
-- >>> discreteLogarithm cg rt x
-- 7
discreteLogarithm :: CyclicGroup Integer m -> PrimitiveRoot m -> MultMod m -> Natural
discreteLogarithm cg (multElement . unPrimitiveRoot -> a) (multElement -> b) = case cg of
CG2
-> 0
-- the only valid input was a=1, b=1
CG4
-> if unMod b == 1 then 0 else 1
-- the only possible input here is a=3 with b = 1 or 3
CGOddPrimePower (unPrime -> p) k
-> discreteLogarithmPP p k (toInteger (unMod a)) (toInteger (unMod b))
CGDoubleOddPrimePower (unPrime -> p) k
-> discreteLogarithmPP p k (toInteger (unMod a) `rem` p^k) (toInteger (unMod b) `rem` p^k)
-- we have the isomorphism t -> t `rem` p^k from (Z/2p^kZ)* -> (Z/p^kZ)*
```