{-# LANGUAGE BangPatterns, DeriveDataTypeable #-}
{- |
Module      :  Numeric.VariablePrecision.Float
Copyright   :  (c) Claude Heiland-Allen 2012
License     :  BSD3

Maintainer  :  claudiusmaximus@goto10.org
Stability   :  provisional
Portability :  BangPatterns, DeriveDataTypeable

Variable precision software floating point based on @(Integer, Int)@ as
used by 'decodeFloat'.

Accuracy has not been extensively verified, and termination of numerical
algorithms has not been proven.

'floatRange' is arbitrarily limited to mitigate the problems that
occur when enormous integers might be needed during some number
type conversions (worst case consequence: program abort in gmp).

No support for infinities, NaNs, negative zero or denormalization:

  * exponent overflow throws an error instead of resulting in infinity,

  * exponent underflow traces a warning and results in zero instead of
    resulting in a denormalized number.

Some operations throw errors instead of resulting in an infinity or NaN:

  * @'recip' 0@,

  * @x '/' 0@,

  * @'sqrt' x | x < 0@,

  * @'log' x | x <= 0@.

The 'Floating' instance so far only implements algorithms for:

  * 'pi',

  * 'sqrt',

  * 'exp',

  * 'log'

with the other 'Floating' methods going via 'Double' (so they are limited
to 'Double' precision and range).  The precision of 'log' is also limited,
due to its internal use of @log 2 :: Double@.

-}
module Numeric.VariablePrecision.Float
  ( VFloat()
  , recodeFloat
  , module Numeric.VariablePrecision.Precision
  , module TypeLevel.NaturalNumber.ExtraNumbers
  ) where

import Data.Data (Data())
import Data.Typeable (Typeable())

import Data.Bits (bit, shiftL, shiftR)
import Data.Monoid (mappend)
import Data.Ratio ((%), numerator, denominator)

import GHC.Float (showSignedFloat)
import Numeric (readSigned, readFloat)
import Text.FShow.RealFloat (DispFloat(), FShow(fshowsPrec), fshowFloat)

import Debug.Trace (trace) -- FIXME

import Numeric.VariablePrecision.Precision
import TypeLevel.NaturalNumber.ExtraNumbers (N24, n24, N53, n53)

-- | A software implementation of floating point arithmetic, using a strict
--   pair of 'Integer' and 'Int', scaled similarly to 'decodeFloat'.
--
--   A value @F m e@ denotes @m * 2^e@.  Non-zero values produced by
--   'encodeFloat' have a mantissa whose magnitude lies in
--   @[2^(b-1), 2^b)@, where @b@ is the precision; 'encodeFloat' builds
--   zero as @F 0 0@.  Equality and ordering rely on this normalisation.
--   The type parameter @p@ does not appear in the fields: it only
--   selects the precision reported by 'precision' / 'floatDigits'.
data VFloat p = F !Integer !Int deriving (Data, Typeable)

-- | Convert between generic 'RealFloat' types
--   more efficiently than 'realToFrac': the source's mantissa and
--   exponent (from 'decodeFloat') are re-encoded with 'encodeFloat'.
recodeFloat :: (RealFloat a, RealFloat b) => a -> b
recodeFloat x =
  let (mantissa, expo) = decodeFloat x
  in  encodeFloat mantissa expo

-- Display goes through the 'FShow' machinery: 'fshowFloat' renders the
-- magnitude and 'showSignedFloat' adds the sign and any parentheses
-- needed at the given precedence.  'Show' reuses the same rendering.
instance NaturalNumber p => DispFloat (VFloat p)
instance NaturalNumber p => FShow (VFloat p) where
  fshowsPrec = showSignedFloat fshowFloat
instance NaturalNumber p => Show (VFloat p) where
  showsPrec d = fshowsPrec d

-- Parses an optionally signed decimal literal with 'readSigned' and
-- 'readFloat'; the precedence argument is discarded.
instance NaturalNumber p => Read (VFloat p) where
  readsPrec = const (readSigned readFloat) -- FIXME ignores precedence

instance HasPrecision VFloat
instance VariablePrecision VFloat where
  -- Re-express a value at the target precision, keeping the mantissa
  -- normalised (magnitude in [2^(q-1), 2^q) for target precision q).
  -- Zero always becomes F 0 0.
  adjustPrecision (F 0 _) = F 0 0
  adjustPrecision x@(F m e) = result
    where
      result
        -- Gaining (or keeping) precision: exact, just append zero bits.
        | n >= 0 = checkVFloat (F (m `shiftL` n) (e - n))
        -- Losing precision: truncate the magnitude toward zero, as
        -- 'encodeFloat' does.  An arithmetic 'shiftR' on a negative
        -- mantissa rounds toward negative infinity and can carry into an
        -- extra bit (e.g. -15 `shiftR` 1 == -8 for a 3-bit target),
        -- which would leave the mantissa unnormalised.
        | otherwise =
            checkVFloat (F (signum m * (abs m `shiftR` negate n)) (e - n))
      -- Change in precision (target minus source), in bits.
      n = nq - np
      np = precision x
      nq = precision result

-- | Equality by value.  All zeros are equal whatever exponent they carry.
--   Non-zero values are compared structurally, which is only correct
--   because non-zero mantissas are kept normalised.
instance Eq (VFloat p) where
  F m e == F m' e'
    | m == 0, m' == 0 = True
    | otherwise = m == m' && e == e'
  u /= v = not (u == v)

-- | Ordering by value.  Zero is handled first, because its exponent is
--   irrelevant.  For two non-zero values of the same sign the exponent
--   decides first (reversed when both are negative) and the mantissa
--   breaks ties.  This relies on non-zero mantissas being normalised.
instance Ord (VFloat p) where
  compare (F 0 _) (F x _) = compare 0 x
  compare (F a _) (F 0 _) = compare a 0
  compare (F a b) (F x y) =
    case (a > 0, x > 0) of
      (True,  True)  -> compare b y `mappend` compare a x
      (True,  False) -> GT
      (False, True)  -> LT
      (False, False) -> compare y b `mappend` compare a x

-- | Arithmetic on mantissa/exponent pairs.  Results are renormalised by
--   'encodeFloat' (which truncates excess bits) and range-checked by
--   'checkVFloat'.
instance NaturalNumber p => Num (VFloat p) where
  -- Addition: align both mantissas to the larger exponent, dropping the
  -- low bits of the smaller operand with an arithmetic shift.
  F 0 _ + v = v
  u + F 0 _ = u
  F a b + F x y =
    checkVFloat $ encodeFloat ((a `shiftR` (s - b)) + (x `shiftR` (s - y))) s
    where s = max b y
  -- Subtraction: same alignment as addition.
  F 0 _ - v = checkVFloat (negate v)
  u - F 0 _ = checkVFloat u
  F a b - F x y =
    checkVFloat $ encodeFloat ((a `shiftR` (s - b)) - (x `shiftR` (s - y))) s
    where s = max b y
  -- Multiplication: the exact product has about twice the precision in
  -- bits; shift away all but a couple of extra bits before
  -- renormalising.
  u@(F 0 _) * _ = checkVFloat u
  _ * v@(F 0 _) = checkVFloat v
  u@(F a b) * F x y = checkVFloat $ encodeFloat ((a * x) `shiftR` t) (b + y + t)
    where t = precision u - 2
  negate (F a b) = checkVFloat (F (negate a) b)
  abs (F a b) = checkVFloat (F (abs a) b)
  signum (F a _) = fromInteger (signum a)
  fromInteger i = checkVFloat (encodeFloat i 0)

-- | Exact conversion to 'Rational': @F m e@ is @m * 2^e@.
instance NaturalNumber p => Real (VFloat p) where
  toRational (F 0 _) = 0
  toRational (F m e)
    | e < 0 = m % bit (negate e)
    | otherwise = fromInteger (m `shiftL` e)

-- | Division by integer 'quot' on scaled mantissas.  Division by zero
--   (and 'recip' of zero) raises an error instead of producing an
--   infinity.
instance NaturalNumber p => Fractional (VFloat p) where
  _ / F 0 _ = error "Numeric.VFloat./0" -- FIXME
  n@(F 0 _) / _ = checkVFloat n
  -- Pre-scale the numerator by 2^(precision + 2) so the integer quotient
  -- keeps enough bits.
  n@(F a b) / F x y = checkVFloat (encodeFloat q (b - y - s)) -- FIXME accuracy
    where
      s = precision n + 2
      q = (a `shiftL` s) `quot` x
  recip (F 0 _) = error "Numeric.VFloat.recip 0" -- FIXME
  -- 1 / (x * 2^y) = (2^s / x) * 2^(-y - s), with s large enough that the
  -- integer quotient keeps enough bits.
  recip d@(F x y) = checkVFloat (encodeFloat q (negate y - s)) -- FIXME accuracy
    where
      s = 2 * precision d + 2
      q = bit s `quot` x
  fromRational r =
    checkVFloat (fromInteger (numerator r) / fromInteger (denominator r)) -- FIXME accuracy

-- | Split a value into an integer part, truncated toward zero, and a
--   fractional part with the same sign as the input and magnitude < 1.
--   'truncate', 'round', 'floor' and 'ceiling' use the class defaults,
--   which are built on 'properFraction'.
instance NaturalNumber p => RealFrac (VFloat p) where
  properFraction (F 0 _) = (0, checkVFloat $ 0)
  properFraction me@(F m e)
    -- No fractional bits: the value is exactly m * 2^e.
    | e >= 0 = (fromInteger (m `shiftL` e), checkVFloat $ 0)
    -- |value| < 2^(precision + e) < 1, so it is entirely fraction.
    | e < negate (precision me) = (0, checkVFloat $ me)
    | otherwise = (fromInteger n, checkVFloat $ me - d)
    where
      k = negate e
      -- Integer part, truncated toward zero.  'quot' is used rather than
      -- an arithmetic shift, which rounds negative values toward
      -- -infinity and would need a sign correction afterwards.
      n = m `quot` bit k
      -- n * 2^k expressed with the same exponent as 'me', so the
      -- subtraction above is exact.  It may be unnormalised (or zero);
      -- it is only ever used as an operand of (-).
      d = F (n `shiftL` k) e

-- | Binary floating point with a precision-dependent number of mantissa
--   digits and no special values (no NaN, infinities, denormals or
--   negative zero).
instance NaturalNumber p => RealFloat (VFloat p) where
  floatRadix _ = 2
  floatDigits = precision
  floatRange _ = (negate (bit 20), bit 20) -- FIXME
  -- this floatRange is somewhat arbitrary, but toInteger gives integers
  -- with up to around (precision + maxExponent) bits, the value here
  -- gives rise to potentially more than 300k decimal digits...
  isNaN _ = False
  isInfinite _ = False
  isDenormalized _ = False
  isNegativeZero _ = False
  isIEEE _ = False
  -- Zero always decodes as (0, 0), whatever exponent it carries.
  decodeFloat (F 0 _) = (0, 0)
  decodeFloat (F m e) = (m, e)
  -- Normalise (m, e) so that the mantissa magnitude lies in
  -- [2^(b-1), 2^b), where b is the precision.  Excess low bits are
  -- truncated toward zero, because the loop works on 'abs m' and
  -- reapplies the sign at the end.  The loop first brings the magnitude
  -- into [2^b, 2^(b+1)) one bit at a time, then drops one final bit.
  encodeFloat 0 _ = F 0 0
  encodeFloat m e = result
    where
      result = checkVFloat $ encodeFloat' (signum m) (abs m) e
      b = precision result
      hi = bit (b + 1)
      lo = bit b
      -- s' is the sign, m' the (positive) magnitude, e' the exponent.
      encodeFloat' !s' !m' !e'
        -- Expected to be unreachable: m' starts positive and the shifts
        -- below never take it to zero.
        | m' <= 0 = failed -- FIXME
        | lo <= m' && m' < hi = F (s' * (m' `shiftR` 1)) (e' + 1)
        | m' <  lo = encodeFloat' s' (m' `shiftL` 1) (e' - 1)
        | hi <= m' = encodeFloat' s' (m' `shiftR` 1) (e' + 1)
        -- Expected to be unreachable: the three guards above cover
        -- every positive m'.
        | otherwise = failed -- FIXME
        where
          failed = error $ "Numeric.VariablePrecision.VFloat.encodeFloat\n"
                        ++ show (m, e, b, lo, hi, s', m', e')
                        ++ "\nplease report this as a bug."

-- | Native algorithms for 'pi', 'sqrt', 'exp' and 'log'.  All other
--   methods convert to 'Double' and back via 'viaDouble', so they are
--   limited to 'Double' precision and range.
instance NaturalNumber p => Floating (VFloat p) where -- FIXME

  -- AGM iteration (cf. the linked article).  a and b are the arithmetic
  -- and geometric means, s accumulates the correction sum using weight
  -- k, and p is the current approximation.  The loop stops when two
  -- successive approximations are equal at this precision.
  -- <http://en.wikipedia.org/wiki/AGM_method>
  pi = checkVFloat $ go 1 (sqrt 0.5) 1 2 0
    where
      go a b s k p
        | p == p' = p'
        | otherwise = go a' b' s' k' p'
        where
          a' = (a + b) / 2
          b' = sqrt (a * b)
          c  = (a - b) / 2
          s' = s - k' * c * c
          k' = 2 * k
          p' = 4 * a' * a' / s

  -- Newton's method: r' = (r + f / r) / 2, starting from 1 and stopping
  -- at a fixed point (termination is not proven; see the module header).
  -- A negative argument matches neither guard and fails at runtime.
  sqrt f
    | 0 == f = F 0 0
    | 0 <  f = checkVFloat $ go 1
    where
      go !r =
        let r' = (r + f / r) / 2
        in  if r == r' then r else go r'

  -- Power series: the sum of f^n / n!, stopping when adding the next
  -- term no longer changes the sum.  go's arguments are the partial sum,
  -- the current factorial, the current power of f, and the counter.
  exp f = checkVFloat $ go 0 1 1 1
    where
      go !e !nf !fn !n =
        let e' = e + fn / nf
        in  if e == e' then e else go e' (nf * n) (f * fn) (n + 1)

  -- <http://en.wikipedia.org/wiki/Logarithm#Arithmetic-geometric_mean_approximation>
  -- log f ~= pi / (2 * agm 1 (4 / s)) - m * log 2, with s = f * 2^m, and
  -- m chosen so that s is large relative to the precision (derivation
  -- below).  4 / s is computed as encodeFloat 1 (2 - m) / f.
  -- A non-positive argument matches no guard and fails at runtime.
  log f@(F _ e)
    | f > 0 = checkVFloat $ pi / (2 * agm 1 (encodeFloat 1 (2 - m) / f)) - fromIntegral m * ln2
    where
      p = precision f
      -- f ~= sqrt 2 * 2^(p + e)
      -- f * 2^m > (sqrt 2) ^ p
      -- sqrt 2 * 2 ^ (p + e) * 2 ^ m > sqrt 2 ^ p
      -- 1/2 + p + e + m > p / 2
      -- 1 + p + 2 e + 2 m > 0
      m = negate $ p `div` 2 + e
      -- Arithmetic-geometric mean, iterated until it stops changing.
      -- Both arguments carry bang patterns: the original "!a! b" made
      -- the second '!' a suffix operator under GHC >= 9.0's
      -- whitespace-sensitive operator rules.
      agm !a !b =
        let a' = (a + b) / 2
            b' = sqrt (a * b)
        in  if a' == b' || (a == a' && b == b') then a' else agm a' b'
      -- Only Double precision: this limits the accuracy of 'log'.
      ln2 = viaDouble log 2 -- FIXME

  sin = viaDouble sin -- FIXME
  cos = viaDouble cos -- FIXME
  tan = viaDouble tan -- FIXME
  sinh = viaDouble sinh -- FIXME
  cosh = viaDouble cosh -- FIXME
  tanh = viaDouble tanh -- FIXME
  asin = viaDouble asin -- FIXME
  acos = viaDouble acos -- FIXME
  atan = viaDouble atan -- FIXME
  asinh = viaDouble asinh -- FIXME
  acosh = viaDouble acosh -- FIXME
  atanh = viaDouble atanh -- FIXME

-- | Lift a 'Double' function to 'VFloat': convert the argument to
--   'Double', apply the function, reject NaN/infinite results with
--   'checkDouble', then convert back.  Precision and range are limited to
--   those of 'Double'.
viaDouble :: NaturalNumber p => (Double -> Double) -> (VFloat p -> VFloat p)
viaDouble f x = recodeFloat (checkDouble (f (recodeFloat x)))

-- | Pass a finite 'Double' through unchanged.  NaN and infinities raise
--   an error, because 'VFloat' cannot represent them.
checkDouble :: Double -> Double
checkDouble x =
  if isNaN x
    then error "Numeric.VariablePrecision.Float: isNaN" -- FIXME
    else if isInfinite x
      then error "Numeric.VariablePrecision.Float: isInfinite" -- FIXME
      else x

-- | Enforce 'floatRange' on the stored exponent.  Exponent underflow
--   emits a 'trace' warning and yields zero; exponent overflow raises an
--   error.  Values whose exponent is in range are returned unchanged.
checkVFloat :: NaturalNumber p => VFloat p -> VFloat p
checkVFloat v@(F _ ex)
  | ex < minExp = trace ("Numeric.VariablePrecision.Float underflow: " ++ show v) 0 -- FIXME
  | ex > maxExp = error ("Numeric.VariablePrecision.Float overflow: " ++ show v) -- FIXME
  | otherwise = v
  where
    (minExp, maxExp) = floatRange v