-- |
-- Module:      Math.NumberTheory.Primes.Factorisation.Montgomery
-- Copyright:   (c) 2011 Daniel Fischer
-- Licence:     MIT
-- Maintainer:  Daniel Fischer <daniel.is.fischer@googlemail.com>
-- Stability:   Provisional
-- Portability: Non-portable (GHC extensions)
--
-- Factorisation of 'Integer's by the elliptic curve algorithm after Montgomery.
-- The algorithm is explained at
-- <http://programmingpraxis.com/2010/04/23/modern-elliptic-curve-factorization-part-1/>
-- and
-- <http://programmingpraxis.com/2010/04/27/modern-elliptic-curve-factorization-part-2/>
--
-- The implementation is not very optimised, so it is not suitable for factorising numbers
-- with only huge prime divisors. However, factors of 20-25 digits are normally found in
-- acceptable time. The time taken depends, however, strongly on how lucky the curve-picking
-- is. With luck, even large factors can be found in seconds; on the other hand, finding small
-- factors (about 10 digits) can take minutes when the curve-picking is bad.
--
-- Given enough time, the algorithm should be able to factor numbers of 100-120 digits, but it
-- is best suited for numbers of up to 50-60 digits.
{-# LANGUAGE CPP, BangPatterns, MagicHash #-}
{-# OPTIONS_HADDOCK hide #-}
module Math.NumberTheory.Primes.Factorisation.Montgomery
  ( -- *  Complete factorisation functions
    -- ** Functions with input checking
    factorise
  , defaultStdGenFactorisation
    -- ** Functions without input checking
  , factorise'
  , stepFactorisation
  , defaultStdGenFactorisation'
    -- * Partial factorisation
  , smallFactors
  , stdGenFactorisation
  , curveFactorisation
    -- ** Single curve worker
  , montgomeryFactorisation
  , findParms
  ) where

#include "MachDeps.h"

import GHC.Base
import GHC.Word
import Data.Array.Base

import System.Random
import Control.Monad.State.Strict
import Control.Applicative
import Data.Bits
import Data.Maybe

import Math.NumberTheory.Logarithms
import Math.NumberTheory.Logarithms.Internal
import Math.NumberTheory.Primes.Sieve.Eratosthenes
import Math.NumberTheory.Primes.Sieve.Indexing
import Math.NumberTheory.Primes.Testing.Probabilistic
import Math.NumberTheory.Utils

-- | @'factorise' n@ produces the prime factorisation of @n@, including
--   a factor of @(-1)@ if @n < 0@. @'factorise' 0@ is an error and the
--   factorisation of @1@ is empty. Uses a 'StdGen' produced in an arbitrary
--   manner from the bit-pattern of @n@.
factorise :: Integer -> [(Integer,Int)]
factorise n = case compare n 0 of
    LT -> (-1,1) : factorise (negate n)
    EQ -> error "0 has no prime factorisation"
    GT | n == 1    -> []
       | otherwise -> factorise' n

-- | Like 'factorise', but without input checking, hence @n > 1@ is required.
--   The generator is seeded from the low bits of @n@ mixed with a constant.
factorise' :: Integer -> [(Integer,Int)]
factorise' n = defaultStdGenFactorisation' gen n
  where
    gen = mkStdGen (fromInteger n `xor` 0xdeadbeef)

-- | @'stepFactorisation'@ is like 'factorise'', except that it doesn't use a
--   pseudo random generator but steps through the curves in order.
--   This strategy turns out to be surprisingly fast, on average it doesn't
--   seem to be slower than the 'StdGen' based variant.
stepFactorisation :: Integer -> [(Integer,Int)]
stepFactorisation n = smallPart ++ maybe [] viaCurves remainder
  where
    -- Trial division by all primes up to 100000 first.
    (smallPart, remainder) = smallFactors 100000 n
    -- No prime factor below 10^5 remains, so anything up to 10^10 is prime.
    viaCurves = curveFactorisation (Just 10000000000) bailliePSW nextSeed 6 Nothing
    -- Seeds 6, 7, 8, ... in order; a seed must stay below m-1. The check is
    -- only performed when the seed itself is demanded.
    nextSeed m k = (if k < m - 1 then k else error "Curves exhausted", k + 1)

-- | @'defaultStdGenFactorisation'@ first strips off all small prime factors and then,
--   if the factorisation is not complete, proceeds to curve factorisation.
--   For negative numbers, a factor of @-1@ is included, the factorisation of @1@
--   is empty. Since @0@ has no prime factorisation, a zero argument causes
--   an error.
defaultStdGenFactorisation :: StdGen -> Integer -> [(Integer,Int)]
defaultStdGenFactorisation sg n
    | n > 1     = defaultStdGenFactorisation' sg n
    | n == 1    = []
    | n < 0     = (-1,1) : defaultStdGenFactorisation sg (negate n)
    | otherwise = error "0 has no prime factorisation"

-- | Like 'defaultStdGenFactorisation', but without input checking, so
--   @n@ must be larger than @1@. Trial division handles primes up to
--   @100000@; whatever is left is passed to 'stdGenFactorisation'.
defaultStdGenFactorisation' :: StdGen -> Integer -> [(Integer,Int)]
defaultStdGenFactorisation' sg n = trialPart ++ maybe [] finish leftover
  where
    (trialPart, leftover) = smallFactors 100000 n
    -- Every remaining prime exceeds 10^5, so numbers up to 10^10 are prime.
    finish = stdGenFactorisation (Just 10000000000) sg Nothing

----------------------------------------------------------------------------------------------------
--                                    Factorisation wrappers                                      --
----------------------------------------------------------------------------------------------------

-- | A wrapper around 'curveFactorisation' providing a few default arguments.
--   The primality test is 'bailliePSW', the @prng@ function - naturally -
--   'randomR' over the admissible seed range @[6, m-2]@. This function also
--   requires small prime factors to have been stripped before.
stdGenFactorisation :: Maybe Integer    -- ^ Lower bound for composite divisors
                    -> StdGen           -- ^ Standard PRNG
                    -> Maybe Int        -- ^ Estimated number of digits of smallest prime factor
                    -> Integer          -- ^ The number to factorise
                    -> [(Integer,Int)]  -- ^ List of prime factors and exponents
stdGenFactorisation primeBound sg digits n =
    curveFactorisation primeBound bailliePSW pickSeed sg digits n
  where
    pickSeed m = randomR (6, m - 2)

-- | @'curveFactorisation'@ is the driver for the factorisation. Its performance (and success)
--   can be influenced by passing appropriate arguments. If you know that @n@ has no prime divisors
--   below @b@, any divisor found less than @b*b@ must be prime, thus giving @Just (b*b)@ as the
--   first argument allows skipping the comparatively expensive primality test for those.
--   If @n@ is such that all prime divisors must have a specific easy to test for structure, a
--   custom primality test can improve the performance (normally, it will make very little
--   difference, since @n@ has not many divisors, and many curves have to be tried to find one).
--   The pseudo random generator has a bigger influence: it is a function @prng@ with
--   @6 <= fst (prng k s) <= k-2@ together with an initial state, and it determines which curves
--   are tried. A lucky choice here can make a huge difference. So, if the default takes too long,
--   try another one; or you can improve your chances for a quick result by running several
--   instances in parallel.
--
--   @'curveFactorisation'@ requires that small prime factors have been stripped before. Also, it is
--   unlikely to succeed if @n@ has more than one (really) large prime factor.
--   The factorisation of @1@ is empty, and @1@ is never reported as a factor.
curveFactorisation :: Maybe Integer                 -- ^ Lower bound for composite divisors
                   -> (Integer -> Bool)             -- ^ A primality test
                   -> (Integer -> g -> (Integer,g)) -- ^ A PRNG
                   -> g                             -- ^ Initial PRNG state
                   -> Maybe Int                     -- ^ Estimated number of digits of the smallest prime factor
                   -> Integer                       -- ^ The number to factorise
                   -> [(Integer,Int)]               -- ^ List of prime factors and exponents
curveFactorisation primeBound primeTest prng seed mbdigs n
    | n == 1    = []
    | ptest n   = [(n,1)]
    | otherwise = evalState (fact n digits) seed
      where
        -- Default guess for the number of digits of the smallest prime factor.
        digits = fromMaybe 8 mbdigs
        -- Multiply every exponent of a factor list by j.
        mult 1 xs = xs
        mult j xs = [(p,j*k) | (p,k) <- xs]
        -- Double all exponents of a (primes, composites) result; used for
        -- the common part g, which divides m twice.
        dbl (u,v) = (mult 2 u, mult 2 v)
        -- Numbers up to the bound are known to be prime (see above), so only
        -- larger ones go through the primality test.
        ptest = case primeBound of
                  Just bd -> \k -> k <= bd || primeTest k
                  Nothing -> primeTest
        -- Draw the next curve seed for modulus k from the PRNG state.
        rndR k = state (\gen -> prng k gen)
        -- Fully factor the composite m: run a batch of curves tuned for
        -- digs-digit factors, then recurse on the composites left over,
        -- aiming at larger factors if the batch found no prime at all.
        fact m digs = do let (b1,b2,ct) = findParms digs
                         (pfs,cfs) <- repFact m b1 b2 ct
                         if null cfs
                           then return pfs
                           else do
                               nfs <- forM cfs $ \(k,j) ->
                                   mult j <$> fact k (if null pfs then digs+4 else digs)
                               return (mergeAll $ pfs:nfs)
        -- Try curves on the composite m with a budget of count+1 attempts.
        -- Returns the prime factors found and the composite cofactors that
        -- could not be split within the budget, both with multiplicities.
        repFact m b1 b2 count
            | count < 0 = return ([],[(m,1)])
            | otherwise = do
                s <- rndR m
                case montgomeryFactorisation m b1 b2 s of
                  Nothing -> repFact m b1 b2 (count-1)
                  Just d  -> do
                      let !cof = m `quot` d
                      case gcd cof d of
                        1 -> do
                            (dp,dc) <- classify d
                            (cp,cc) <- classify cof
                            return (merge dp cp, dc ++ cc)
                        -- Here m = (d/g) * (cof/g) * g^2. Either quotient
                        -- may be 1 (e.g. m = p^2 gives d = cof = g = p);
                        -- such a 1 must be dropped, not reported as a prime.
                        g -> do
                            (dp,dc) <- classify (d `quot` g)
                            (cp,cc) <- classify (cof `quot` g)
                            (gp,gc) <- dbl <$> classify g
                            return (mergeAll [dp,cp,gp], dc ++ cc ++ gc)
          where
            -- Sort a cofactor: 1 contributes nothing, a number passing the
            -- primality test is a prime factor, anything else is attacked
            -- with the remaining curve budget.
            classify k
                | k == 1    = return ([],[])
                | ptest k   = return ([(k,1)],[])
                | otherwise = repFact k b1 b2 (count-1)

----------------------------------------------------------------------------------------------------
--                                         The workhorse                                          --
----------------------------------------------------------------------------------------------------

-- | @'montgomeryFactorisation' n b1 b2 s@ tries to find a factor of @n@ using the
--   curve and point determined by the seed @s@ (@6 <= s < n-1@), multiplying the
--   point by the least common multiple of all numbers @<= b1@ and all primes
--   between @b1@ and @b2@. The idea is that there's a good chance that the order
--   of the point in the curve over one prime factor divides the multiplier, but the
--   order over another factor doesn't, if @b1@ and @b2@ are appropriately chosen.
--   If they are too small, none of the orders will probably divide the multiplier,
--   if they are too large, all probably will, so they should be chosen to fit
--   the expected size of the smallest factor.
--
--   It is assumed that @n@ has no small prime factors.
--
--   The result is maybe a nontrivial divisor of @n@. @Nothing@ means either that
--   the gcd of @n@ with the final @z@-coordinate was @1@, or that the
--   @z@-coordinate became @0@ (the gcd would then be @n@ itself). A @Just g@
--   comes from @gcd n z@ with @z@ nonzero and reduced (@rem@) modulo @n@, hence
--   @1 < g < n@.
montgomeryFactorisation :: Integer -> Word -> Word -> Integer -> Maybe Integer
montgomeryFactorisation n b1 b2 s = go p5 (list primeStore)
  where
    -- Exponent for the power of 2; assuming wordLog2' is the floor of the
    -- binary logarithm, 2^l2 is the largest power of 2 not exceeding b1.
    l2 = wordLog2' b1
    b1i = toInteger b1
    -- Word exponentiation with the exponent fixed to Int.
    (^~) :: Word -> Int -> Word
    w ^~ i = w ^ i
    (e, p0) = montgomeryData n s
    dbl pt = double n e pt
    -- Double the point k times, i.e. multiply it by 2^k.
    dbln 0 !pt = pt
    dbln k pt = dbln (k-1) (dbl pt)
    -- Starting point multiplied by 2^l2 (the factor 2 of the multiplier).
    p2 = dbln l2 p0
#if WORD_SIZE_IN_BITS == 64
    mul a b c = (a*b) `quot` c       -- can't overflow, work on Int
#else
    mul a b c = fromInteger ((toInteger a * b) `quot` c) -- might overflow if Int is used
#endif
    -- Multiply w by ml as long as it does not exceed bd. Called with
    -- bd = b1 `quot` ml and a power of ml that is at most b1, this ends with
    -- the largest power of ml that does not exceed b1.
    adjust bd ml w
      | w <= bd     = adjust bd ml (w*ml)
      | otherwise   = w
    -- 190537/301994 approximates log 2 / log 3, so 3^l3 is about 2^l2.
    l3 = mul l2 190537 301994
    w3 = 3 ^~ l3
    pw3 = adjust (b1 `quot` 3) 3 w3
    p3 = multiply n e pw3 p2
    -- 1936274/4495889 approximates log 2 / log 5, so 5^l5 is about 2^l2.
    l5 = mul l2 1936274 4495889
    w5 = 5 ^~ l5
    pw5 = adjust (b1 `quot` 5) 5 w5
    p5 = multiply n e pw5 p3
    -- Stage one: for each remaining prime pr <= b1 (presumably starting at 7,
    -- since 2, 3 and 5 are handled above), multiply by the largest power of
    -- pr not exceeding b1. On reaching the first prime above b1, take the
    -- gcd once, then continue with stage two.
    go (P _ 0) _ = Nothing
    go !pt@(P _ z) (pr:prs)
      | pr <= b1    = let !lp = integerLogBase' (fromIntegral pr) b1i
                      in go (multiply n e (pr ^~ lp) pt) prs
      | otherwise   = case gcd n z of
                        1 -> lgo (multiply n e pr pt) prs
                        g -> Just g
    go (P _ z) _    = case gcd n z of
                        1 -> Nothing
                        g -> Just g
    -- Stage two: multiply by each prime up to b2 once, then take the gcd.
    lgo (P _ 0) _ = Nothing
    lgo !pt@(P _ z) (pr:prs)
      | pr <= b2    = lgo (multiply n e pr pt) prs
      | otherwise   = case gcd n z of
                        1 -> Nothing
                        g -> Just g
    lgo (P _ z) _   = case gcd n z of
                        1 -> Nothing
                        g -> Just g

----------------------------------------------------------------------------------------------------
--                            Helpers, Curves and elliptic arithmetics                            --
----------------------------------------------------------------------------------------------------

-- A Montgomery curve is given by y^2 = x^3 + (A_n / A_d - 2)*x^2 + x (mod n).
-- Stored as @C A_n (4*A_d)@: the denominator A_d only ever occurs
-- multiplied by 4 in the formulae (see 'double'), so it is kept pre-scaled.
data Curve = C !Integer !Integer

-- Point in the projective plane, will be on the curve.
-- A coordinate transformation eliminates the y-coordinate, hence
-- we store only x and z, as @P x z@. The arithmetic below keeps both
-- coordinates below n in absolute value (`rem` may leave them negative).
-- A z-coordinate of 0 is treated by 'montgomeryFactorisation' as a
-- failed curve.
data Point = P !Integer !Integer

-- Curve and starting point derived from the seed s (Suyama-style
-- parametrisation with u = s^2 - 5 and v = 4s, everything reduced mod n).
-- Input should satisfy 6 <= s < n-1.
montgomeryData :: Integer -> Integer -> (Curve, Point)
montgomeryData n s =
    let u     = (s*s - 5) `mod` n
        v     = (4*s) `mod` n
        diff  = v - u
        u3    = (u*u*u) `mod` n
        v3    = (v*v*v) `mod` n
        numer = (diff*diff*diff*(3*u + v)) `mod` n
        denom = (16*u3*v) `mod` n
    in (C numer denom, P u3 v3)

-- Differential addition on the curve: given the modulus n and three points
-- p0, p1 and p2 with p0 = p2 - p1, compute p1 + p2.
-- The curve parameters are not needed; only the three points enter.
add :: Integer -> Point -> Point -> Point -> Point
add n (P xd zd) (P xa za) (P xb zb) =
    let cross1 = (xa - za) * (xb + zb)
        cross2 = (xa + za) * (xb - zb)
        s      = cross1 + cross2
        t      = cross1 - cross2
    in P ((s * s * zd) `rem` n) ((t * t * xd) `rem` n)

-- Double a point on the curve. Only x and z are involved; the curve enters
-- through A_n and the pre-scaled 4*A_d, so no division is needed and both
-- new coordinates are reduced with `rem` modulo n.
double :: Integer -> Curve -> Point -> Point
double n (C an ad4) (P x z) = P newX newZ
  where
    sq a    = a * a
    plusSq  = sq (x + z)
    minusSq = sq (x - z)
    gap     = plusSq - minusSq
    newX    = (ad4 * plusSq * minusSq) `rem` n
    newZ    = ((ad4 * minusSq + gap * an) * gap) `rem` n

-- Multiply a point on the curve by a Word, using a Montgomery ladder that
-- scans the multiplier from its high-order bits to its low-order bits
-- (within Word range this is the faster variant).
-- Throughout the ladder p0 = k*p and p1 = (k+1)*p, where k is the prefix of
-- the multiplier processed so far, so p1 - p0 = p is always the difference
-- point that 'add' requires.
--
-- The ladder starts one bit below the highest set bit and stops at bit 0.
-- For a multiplier of 1 it would start at bit -1 and never reach the
-- terminating case, so 1 is answered directly. A multiplier of 0 is rejected.
multiply :: Integer -> Curve -> Word -> Point -> Point
multiply _ _ 0 _ = error "Math.NumberTheory.Primes.Factorisation.Montgomery.multiply: zero multiplier"
multiply _ _ 1 p = p
multiply n cve (W# w##) p =
    case wordLog2# w## of
      l# -> go (l# -# 1#) p (double n cve p)
  where
    go 0# !p0 !p1 = case w## `and#` 1## of
                      0## -> double n cve p0
                      _   -> add n p p0 p1
    go i# p0 p1 = case (uncheckedShiftRL# w## i#) `and#` 1## of
                    0## -> go (i# -# 1#) (double n cve p0) (add n p p0 p1)
                    _   -> go (i# -# 1#) (add n p p0 p1) (double n cve p1)

{-  Not (yet) needed
-- Multiply a point on the curve by an Integer.
multIgr :: Integer -> Curve -> Integer -> Point -> Point
multIgr n cve k p = go k
  where
    go 1 = (p, double n cve p)
    go m = case m `quotRem` 2 of
             (q,r) -> let !(!s, l) = go q
                      in case r of
                           0 -> (double n cve s, add n p s l)
                           _ -> (add n p s l, double n cve l)
-}

-- Primes, compactly stored as bit sieves, obtained from @psieveFrom 7@
-- (presumably covering the primes from 7 upward - confirm against
-- psieveFrom). As a top-level constant it is shared by all its users,
-- 'montgomeryFactorisation' (via 'list') and 'smallFactors'.
primeStore :: [PrimeSieve]
primeStore = psieveFrom 7

-- Generate the list of primes marked in the given sieves. Index i of a sieve
-- with offset vO stands for vO + toPrim i, and (presumably) a set bit marks
-- a prime. The sieves are visited in the order given.
list :: [PrimeSieve] -> [Word]
list = concatMap sievePrimes
  where
    sievePrimes (PS vO bs) =
        let base    = fromInteger vO
            (_, hi) = bounds bs
        in [ base + toPrim i | i <- [0 .. hi], unsafeAt bs i ]

-- | @'smallFactors' bound n@ finds all prime divisors of @n > 1@ up to @bound@ by trial division and returns the
--   list of these together with their multiplicities, and a possible remaining factor which may be composite.
--
--   The factor list is produced lazily: its head is available before the trial division of the
--   remaining cofactor has finished. A cofactor smaller than the square of the next trial prime is
--   necessarily prime and is put into the list (multiplicity 1) instead of being returned as remainder.
smallFactors :: Integer -> Integer -> ([(Integer,Int)], Maybe Integer)
smallFactors bd n =
    case shiftToOddCount n of
      (0, odd0)    -> trial odd0 oddPrimes
      (twos, odd0) ->
          consFactor (2, twos) (if odd0 == 1 then ([], Nothing) else trial odd0 oddPrimes)
  where
    -- Drop the first prime of the sieve list (assumed to be 2, which is
    -- handled by shiftToOddCount above).
    oddPrimes = tail (primeStore >>= primeList)
    -- Prepend a factor without forcing the rest of the result.
    consFactor f rest = (f : fst rest, snd rest)
    trial m primes = case primes of
      [] -> ([(m,1)], Nothing)
      (p:ps)
        | m < p*p   -> ([(m,1)], Nothing)
        | bd < p    -> ([], Just m)
        | otherwise -> case splitOff p m of
            (0, _) -> trial m ps
            (e, r)
              | r == 1    -> ([(p,e)], Nothing)
              | otherwise -> consFactor (p,e) (trial r ps)

-- Merge two factor lists, each sorted by increasing prime, into one sorted
-- list; a prime occurring in both gets the sum of its two exponents.
merge :: [(Integer,Int)] -> [(Integer,Int)] -> [(Integer,Int)]
merge [] qs = qs
merge ps [] = ps
merge ps@(a@(p,i):ps') qs@(b@(q,j):qs')
    | p < q     = a : merge ps' qs
    | p > q     = b : merge ps qs'
    | otherwise = (p, i+j) : merge ps' qs'

-- Merge any number of sorted factor lists with 'merge', pairing them up
-- two at a time.
mergeAll :: [[(Integer,Int)]] -> [(Integer,Int)]
mergeAll lists = case lists of
    []         -> []
    [single]   -> single
    (a:b:rest) -> merge (merge a b) (mergeAll rest)

-- Parameters for the factorisation, the two b-parameters for montgomery and the number of tries
-- to use these, depending on the size of the factor we are looking for.
-- Each entry is (d, b1, b2, count). 'findParms' picks the last entry whose
-- d does not exceed the requested digit count, so the table must stay
-- sorted by increasing d.
-- The numbers are roughly based on the parameters listed on Dario Alpern's ECM site.
testParms :: [(Int,Word,Word,Int)]
testParms = [ (12, 400, 10000, 10), (15, 2000, 50000, 25), (20, 11000, 150000, 90)
            , (25, 50000, 500000, 300), (30, 250000, 1500000, 700)
            , (35, 1000000, 4000000, 1800), (40, 3000000, 12000000, 5100)
            , (45, 11000000, 45000000, 10600), (50, 43000000, 200000000, 19300)
            , (55, 80000000, 400000000,30000), (60, 120000000, 700000000, 50000)
            ]

-- Choose (b1, b2, count) for factors of about digs digits: the parameters of
-- the last 'testParms' entry whose digit count is at most digs, or
-- (100, 1000, 7) when digs is below the first entry.
findParms :: Int -> (Word, Word, Int)
findParms digs = last (fallback : applicable)
  where
    fallback = (100, 1000, 7)
    applicable = [ (b1, b2, ct) | (_, b1, b2, ct) <- takeWhile reached testParms ]
    reached (d, _, _, _) = d <= digs