{-# LANGUAGE BangPatterns #-}
{- |
Module      :  Numeric.VariablePrecision.Algorithms
Copyright   :  (c) Claude Heiland-Allen 2012
License     :  BSD3

Maintainer  :  claude@mathr.co.uk
Stability   :  unstable
Portability :  BangPatterns

Implementations of various floating point algorithms.  Accuracy has not
been extensively verified, and termination has not been proven.

Everything assumes that 'floatRadix' is 2.  This is *not* checked.

Functions taking an @accuracy@ parameter may fail to terminate if
@accuracy@ is too small.  Accuracy is measured in least significant
bits, similarly to '(=~=)'.

In this documentation, /basic functionality/ denotes that methods used
are from classes:

  * 'Num', 'Eq', 'Ord'.

Further, /basic RealFloat functionality/ denotes /basic functionality/
with the addition of:

  * Anything in 'RealFloat' except for 'atan2'.

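Each function documents which of these class methods it relies on.  This
helps users decide when it is appropriate to use these generic
implementations to define class instances.  For example, a generic
function that itself calls 'sqrt' cannot be used to implement 'sqrt'.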

-}
module Numeric.VariablePrecision.Algorithms
  ( recodeFloat
  , viaDouble
  , (=~=)
  , genericRecip
  , genericSqrt
  , genericExp
  , genericLog
  , genericLog'
  , genericLog2
  , genericLog''
  , genericPi
  , genericSin
  , genericPositiveZero
  , genericNegativeZero
  , genericPositiveInfinity
  , genericNegativeInfinity
  , genericNotANumber
  , sameSign
  ) where

import Data.Bits (bit, shiftR)
import Data.List (foldl')

-- | Positive zero, built with basic RealFloat functionality.
genericPositiveZero :: RealFloat a => a
genericPositiveZero = 0

-- | Negative zero, obtained by negating a literal zero.
genericNegativeZero :: RealFloat a => a
genericNegativeZero = negate 0

-- | Positive infinity.
--
--   Asks 'encodeFloat' for @2^floatDigits@ at the largest exponent of
--   'floatRange', relying on that value overflowing the type's range.
genericPositiveInfinity :: RealFloat a => a
genericPositiveInfinity = signedOverflow 1

-- | Negative infinity, built like 'genericPositiveInfinity' but with a
--   negated mantissa.
genericNegativeInfinity :: RealFloat a => a
genericNegativeInfinity = signedOverflow (-1)

-- | Not-a-number, produced as @+infinity + -infinity@.
genericNotANumber :: RealFloat a => a
genericNotANumber = genericPositiveInfinity + genericNegativeInfinity

-- Encode @sign * 2^(floatDigits)@ with the top exponent of 'floatRange'.
-- The witness is never evaluated; it only fixes the type queried by
-- 'floatDigits' and 'floatRange'.
signedOverflow :: RealFloat a => Integer -> a
signedOverflow sign = value
  where
    value   = encodeFloat (sign * bit digits) topExp
    witness = undefined `asTypeOf` value
    digits  = floatDigits witness
    topExp  = snd (floatRange witness)


-- | Convert between generic 'RealFloat' types more efficiently than
--   'realToFrac'.  NaN maps to a NaN (any payload is lost), infinities
--   keep their sign, and negative zero stays negative zero.  Finite
--   values are moved across as their exact mantissa/exponent pair, so any
--   rounding happens inside the target type's 'encodeFloat'.
--
--   Uses only basic RealFloat functionality.
--
recodeFloat :: (RealFloat a, RealFloat b) => a -> b
recodeFloat !x
  | isNaN x             = genericNotANumber
  | isInfinite x, x < 0 = genericNegativeInfinity
  | isInfinite x, x > 0 = genericPositiveInfinity
  | isNegativeZero x    = genericNegativeZero
  | x == 0              = genericPositiveZero
  | otherwise           = encodeFloat mantissa expo
  where
    -- only demanded in the finite, non-zero case
    (mantissa, expo) = decodeFloat x


-- | Check whether two numbers lie on the same side of zero.  Zero counts
--   as its own class, so @sameSign 0 0@ is 'True' but @sameSign 0 1@ is
--   'False'.  May give a nonsense result if an argument is NaN.
sameSign :: (Ord a, Num a) => a -> a -> Bool
sameSign a b = signOf a == signOf b
  where
    signOf = compare 0


-- | Approximate equality.
--   @(a =~= b) c@ holds when adding the difference to the larger in
--   magnitude changes at most @c@ least significant mantissa bits.
--   Concretely: equal values and pairs of NaNs compare equal; a NaN
--   against a non-NaN, or any unequal infinity, does not.  Otherwise both
--   mantissas must share a sign, the exponents from 'decodeFloat' must be
--   within @c@ of each other, and the difference must be at most
--   @2^(c + larger exponent)@.
--
--   Uses only basic RealFloat functionality.
--
(=~=) :: RealFloat a => a -> a -> Int -> Bool
(=~=) !x !y !s
  | x == y                       = True
  | isNaN x                      = isNaN y
  | isNaN y                      = False
  | isInfinite x || isInfinite y = False
  | otherwise                    = sameSign mx my && nearExponents && smallGap
  where
    -- decoded lazily: never demanded for NaN or infinite inputs
    (mx, ex) = decodeFloat x
    (my, ey) = decodeFloat y
    nearExponents = abs (ex - ey) <= s
    smallGap      = abs (x - y) <= encodeFloat 1 (s + max ex ey)


-- | Compute a reciprocal using the Newton-Raphson division algorithm,
--   as described in
--   <http://en.wikipedia.org/wiki/Division_%28digital%29#Newton.E2.80.93Raphson_division>.
--
--   Special inputs are handled as follows:
--
--   * NaN is returned unchanged.
--   * Positive and negative infinity give positive and negative zero.
--   * Negative and positive zero give negative and positive infinity.
--   * Negative inputs are handled by symmetry.
--
--   Otherwise the significand @d@ (in [0.5,1)) is inverted by iterating
--   @r' = r * (2 - d * r)@ from the linear seed @48/17 - 32/17 * d@.
--   The result is then rescaled by the input's power of two.  Iteration
--   stops when successive estimates agree to within @accuracy@ bits (see
--   '(=~=)'), or when a step budget derived from 'floatDigits' runs out.
--
--   Uses only basic RealFloat functionality.
--
genericRecip :: RealFloat a => Int {- ^ accuracy -} -> a -> a
genericRecip accuracy y = invert y
  where
    invert v
      | isNaN v               = v
      | isInfinite v && v > 0 = genericPositiveZero
      | isInfinite v && v < 0 = genericNegativeZero
      | isNegativeZero v      = genericNegativeInfinity
      | v == 0                = genericPositiveInfinity
      | v < 0                 = negate (invert (negate v))
      | otherwise             = scaleFloat rescale (newton budget seed)
      where
        sig     = significand v  -- in [0.5,1)
        rescale = exponent sig - exponent v
        seed    = k48 - k32 * sig
        newton !steps !r
          | (r =~= r') accuracy = r'
          | steps == 0          = r'
          | otherwise           = newton (steps - 1) r'
          where
            r' = scaleFloat 1 r - sig * r * r  -- r * (2 - sig * r)
    -- Per-type constants, bound here so they are not rebuilt per step.
    digits = floatDigits (undefined `asTypeOf` y)
    budget = ceiling (logBase 2 (fromIntegral (digits + 1) / logBase 2 17) :: Double) :: Int
    k48    = recodeFloat (48/17 :: Double)
    k32    = recodeFloat (32/17 :: Double)


-- | Compute a square root using Newton's method.
--
--   Special inputs are handled as follows:
--
--   * Negative inputs, including negative infinity, give NaN.
--   * Zeros of either sign, NaN and positive infinity are returned
--     unchanged.
--
--   Otherwise the input is scaled by an even power of two into [1,4).
--   A 'Double' square root of the scaled value is the starting guess.
--   Newton steps @r' = (r + f / r) / 2@ run until successive estimates
--   agree to within @accuracy@ bits.  The result is then scaled back by
--   half of the shift.
--
--   Uses basic RealFloat functionality and '(/)'.
--
genericSqrt :: RealFloat a => Int {- ^ accuracy -} -> a -> a
genericSqrt accuracy f0
  | f0 < 0        = genericNotANumber
  | f0 == 0       = f0  -- returning the input keeps a negative zero negative
  | isNaN f0      = f0
  | isInfinite f0 = f0
  | otherwise     = newton (viaDouble sqrt scaled)
  where
    ex = exponent f0
    -- an even shift, so that halving it later is exact
    evenShift
      | even ex   = ex - 2
      | otherwise = ex - 1
    scaled = scaleFloat (negate evenShift) f0  -- in [1,4)
    newton !r
      | (r =~= r') accuracy = scaleFloat (evenShift `shiftR` 1) r'
      | otherwise           = newton r'
      where
        r' = scaleFloat (-1) (r + scaled / r)


-- | Compute an exponential by summing its power series.
--
--   For positive @x@ the terms @x^n / n!@ are added until the partial sum
--   stops changing by more than @accuracy@ bits (see '(=~=)').  Negative
--   arguments use @exp x = recip (exp (negate x))@, so the series only
--   ever adds positive terms.
--
--   Special inputs: NaN and positive infinity are returned unchanged,
--   negative infinity gives 0, and 0 gives 1.
--
--   Uses basic RealFloat functionality, '(/)' and 'recip'.
--
genericExp :: RealFloat a => Int {-^ accuracy -} -> a -> a
genericExp accuracy x
  | isNaN x               = x
  | isInfinite x && x < 0 = 0
  | isInfinite x          = x
  | x == 0                = 1
  | x <  0                = recip (genericExp accuracy (negate x))
  | otherwise             = series 0 1 1
  where
    -- acc: partial sum so far; term: x^(k-1) / (k-1)!, the next addend
    series !acc !term !k
      | (acc =~= acc') accuracy = acc'
      | otherwise               = series acc' (term * x / fromIntegral k) (k + 1 :: Int)
      where
        acc' = acc + term


-- | Compute a natural logarithm.
--
--   Obtains @log 2@ from 'genericLog2' and hands it to 'genericLog''.
--   See 'genericLog''' for algorithmic references.
--
--   Uses basic RealFloat functionality, 'sqrt' and 'recip'.
--
genericLog :: RealFloat a => Int {- ^ accuracy -} -> a -> a
genericLog accuracy = genericLog' accuracy ln2
  where
    ln2 = genericLog2 accuracy


-- | Compute @log 2@ as the negation of @log 0.5@.
--
--   See 'genericLog''' for algorithmic references.
--
--   Uses basic RealFloat functionality, 'sqrt' and 'recip'.
--
genericLog2 :: RealFloat a => Int {- ^ accuracy -} -> a
genericLog2 accuracy = negate logOfHalf
  where
    -- 0.5 lies in the [0.5,1) domain that genericLog'' expects
    logOfHalf = genericLog'' accuracy 0.5


-- | Compute a logarithm using decomposition and a value for @log 2@.
--
--   Special inputs are handled as follows:
--
--   * NaN is returned unchanged.
--   * Zeros (of either sign) give negative infinity.
--   * Negative inputs give NaN.
--   * Positive infinity is returned unchanged.
--
--   Otherwise @x = s * 2^e@ with @s@ in [0.5,1), and the result is
--   @e * log 2 + genericLog'' accuracy s@.
--
--   See 'genericLog''' for algorithmic references.
--
--   Uses basic RealFloat functionality, 'sqrt', and 'recip'.
--
genericLog' :: RealFloat a => Int {- ^ accuracy -} -> a {- ^ log 2 -} -> a -> a
genericLog' accuracy ln2 x
  | isNaN x      = x
  | x == 0       = genericNegativeInfinity
  | x <  0       = genericNotANumber
  | isInfinite x = x
  | otherwise    = exponentPart + genericLog'' accuracy mantissa
  where
    e        = exponent x
    mantissa = significand x
    -- when e is 0 the log 2 argument is never forced
    exponentPart
      | e == 0    = 0
      | otherwise = fromIntegral e * ln2


-- | Compute a logarithm for a value in [0.5,1) using the AGM method
--   as described in section 7 of
--   /The Logarithmic Constant: log 2/
--   Xavier Gourdon and Pascal Sebah, May 18, 2010,
--   <http://numbers.computation.free.fr/Constants/Log2/log2.ps>.
--
--   Two arithmetic-geometric mean iterations run side by side.  One
--   starts from @(1, 2^m)@ and the other from @(1, 2^m * x)@, where @m@
--   is half of @accuracy - floatDigits x@ (rounded down).  Each chain
--   accumulates @sum 2^(j-1) * (a_j^2 - b_j^2)@.  The result is
--   @recip (1 - sum)@ of the first chain minus that of the second.
--   Iteration stops once both increments are zero or have exponent at
--   most @accuracy - floatDigits x@.
--
--   The precondition is not checked.
--
--   Uses basic RealFloat functionality, 'sqrt', and 'recip'.
--
genericLog'' :: RealFloat a => Int {- ^ accuracy -} -> a {- ^ value in [0.5,1) -} -> a
genericLog'' accuracy x = answer
  where
    cutoff = accuracy - floatDigits (undefined `asTypeOf` x)
    half   = cutoff `shiftR` 1
    negligible v = v == 0 || exponent v <= cutoff
    answer = agm (-1) 1 (encodeFloat 1 half) 0 1 (scaleFloat half x) 0
    mean p q = scaleFloat (-1) (p + q)
    geo  p q = sqrt (p * q)
    -- k: power of two weighting the current increment;
    -- (a1, b1, acc1) and (a2, b2, acc2): state of the two chains
    agm !k !a1 !b1 !acc1 !a2 !b2 !acc2
      | negligible inc1 && negligible inc2 = recip (1 - acc1') - recip (1 - acc2')
      | otherwise = agm (k + 1) (mean a1 b1) (geo a1 b1) acc1' (mean a2 b2) (geo a2 b2) acc2'
      where
        inc1  = scaleFloat k (a1 * a1 - b1 * b1)
        inc2  = scaleFloat k (a2 * a2 - b2 * b2)
        acc1' = acc1 + inc1
        acc2' = acc2 + inc2


-- | Compute pi using the method described in section 8 of
--   /Multiple-precision zero-finding methods and the complexity of elementary function evaluation/
--   Richard P Brent, 1975 (revised May 30, 2010),
--   <http://arxiv.org/abs/1004.3412>.
--
--   The iteration starts from the means @(1, sqrt 0.5)@ and the
--   correction term @t = 1/4@.  Each step forms the estimate
--   @(a + b)^2 / (4 t)@ and subtracts @2^k * (a' - a)^2@ from @t@.  It
--   stops when successive estimates agree to within @accuracy@ bits.
--
--   Uses basic RealFloat functionality, '(/)', and 'sqrt'.
--
genericPi :: RealFloat a => Int {- ^ accuracy -} -> a
--   Works ok up to around 600,000 bits (178,000 decimal digits) but after
--   that further increase to mantissa precision leads to problems.
--   Output compared against /Pi/ by Scott Hemphill <http://www.gutenberg.org/ebooks/50>.
genericPi accuracy = refine 1 (sqrt 0.5) 0.25 0 1
  where
    square v = v * v
    -- a, b: the two means; t: correction term; k: exponent of the weight
    -- on the next correction; estimate: current approximation of pi
    refine !a !b !t !k !estimate
      | (estimate =~= estimate') accuracy = estimate'
      | otherwise = refine a' b' t' (k + 1) estimate'
      where
        a'        = scaleFloat (-1) (a + b)
        b'        = sqrt (a * b)
        t'        = t - scaleFloat k (square (a' - a))
        estimate' = scaleFloat (-2) (square (a + b) / t)


-- | Compute 'sin' using the method described in section 3 of
--   /Efficient multiple-precision evaluation of elementary functions/
--   David M Smith, 1989,
--   <http://digitalcommons.lmu.edu/math_fac/1/>
--
--   The method has four stages:
--
--   1. Reduce the argument to [0, pi/4] using the symmetries of 'sin'.
--   2. Divide it by @3^k@.
--   3. Sum the Taylor series @y - y^3/3! + y^5/5! - ...@.
--   4. Scale back up with the triple-angle formula
--      @sin (3 y) = sin y * (3 - 4 * sin y ^ 2)@, applied @k@ times.
--
--   Arguments above @2 * pi@ have a whole number of periods removed in
--   one step, so very large arguments still terminate, although with
--   little accuracy.
--
--   Requires a value for pi.  A NaN argument is returned unchanged, and
--   an infinite argument gives NaN.
--
--   Uses basic RealFloat functionality, '(/)', 'sqrt', and 'floor'.
--
genericSin :: RealFloat a => Int {-^ accuracy -} -> a {-^ pi -} -> a {-^ x -} -> a
genericSin accuracy pi x0
  | isNaN x0      = x0
  | isInfinite x0 = x0 - x0  -- infinity minus infinity is NaN
  | otherwise     = reduced taylor x0
  where
    sqr y = y * y
    twoPi = pi * 2
    t :: Double
    t = fromIntegral (floatDigits x0 - accuracy)
    -- Number of triple-angle steps.  It is clamped at 0 because a
    -- negative k would make @3 ^ k@ fail and 'up' never terminate.
    k :: Int
    k = max 0 (round (t / 3))
    three = 3 ^ k
    up 0 !y = y
    up n !y = up (n - 1) (y * (3 - scaleFloat 2 (sqr y)))
    reduced f y
      | y == 0      = y
      | y <  0      = (negate . reduced f . negate) y
      | y <= pi / 4 = (up k . f . (/ three)) y
      | y <= pi / 2 = (sqrt . (1 -) . sqr . reduced f . (pi / 2 -)) y
      | y <= pi     = (reduced f . (pi -)) y
      | y <= twoPi  = (negate . reduced f . (twoPi -)) y
      -- Remove whole periods at once; repeated subtraction of 2*pi never
      -- terminates once y is so large that y - 2*pi == y.  At least one
      -- period is always removed, so every step makes progress.
      | otherwise   = reduced f (y - twoPi * fromInteger (max 1 (floor (y / twoPi))))
    -- Sum terms from smallest to largest for accuracy, ignoring those
    -- below y * 2^(-2 * floatDigits).
    taylor y = (sum' . reverse . takeWhile ((> threshold) . abs) . go 1) y
      where
        threshold = scaleFloat (negate (floatDigits y * 2)) y
        x2 = sqr y
        -- xnf is the n-th term (-1)^(n-1) * y^(2n-1) / (2n-1)!.  The next
        -- term is obtained by multiplying by -y^2 / ((2n) * (2n+1)).
        go !n !xnf = xnf : go n' xnf'
          where
            n' = n + 1
            xnf' = negate (x2 * xnf / fromInt (2 * n * (2 * n + 1)))
    fromInt :: Num a => Int -> a
    fromInt = fromIntegral

-- | Add up a list from left to right, starting at 0.  The running total
--   is forced at every step ('foldl''), so no chain of thunks builds up.
sum' :: Num a => [a] -> a
sum' xs = foldl' plus 0 xs
  where
    plus acc x = acc + x

-- | Lift a function from 'Double' to generic 'RealFloat' types.  The
--   argument is converted to 'Double' with 'recodeFloat', the function is
--   applied, and the result is converted back with 'recodeFloat'.
--   Intermediate precision and range are therefore those of 'Double'.
viaDouble :: (RealFloat a, RealFloat b) => (Double -> Double) -> a -> b
viaDouble f x = recodeFloat (f (recodeFloat x))


-- FIXME everything assumes that floatRadix is 2 without checking