{-# LANGUAGE BangPatterns, DeriveDataTypeable, Rank2Types #-}
{- |
Module      :  Numeric.VariablePrecision.Float
Copyright   :  (c) Claude Heiland-Allen 2012
License     :  BSD3

Maintainer  :  claudiusmaximus@goto10.org
Stability   :  unstable
Portability :  BangPatterns, DeriveDataTypeable, Rank2Types

Variable precision software floating point based on @(Integer, Int)@ as
used by 'decodeFloat'.  Supports infinities and NaN, but not negative
zero or denormalization.

Accuracy has not been extensively verified, and termination of numerical
algorithms has not been proven.

-}
module Numeric.VariablePrecision.Float
  ( VFloat()
  , Normed(norm1, norm2, norm2Squared, normInfinity)
  , effectivePrecisionWith
  , effectivePrecision
  , (-@?)
  , DFloat(..)
  , toDFloat
  , fromDFloat
  , withDFloat
  ) where

import Data.Data (Data())
import Data.Typeable (Typeable())

import Data.Bits (bit, shiftL, shiftR)
import Data.Ratio ((%), numerator, denominator)

import GHC.Float (showSignedFloat)
import qualified GHC.Real as R (infinity, notANumber)
import Numeric (readSigned, readFloat)
import Text.FShow.RealFloat (DispFloat(), FShow(fshowsPrec), fshowFloat)

import Numeric.VariablePrecision.Algorithms
import Numeric.VariablePrecision.Precision
import Numeric.VariablePrecision.Precision.Reify
import Numeric.VariablePrecision.Integer.Logarithm


-- | A software implementation of floating point arithmetic, using a strict
--   pair of 'Integer' and 'Int', scaled similarly to 'decodeFloat', along
--   with additional values representing:
--
--     * positive infinity (@1/0@),
--
--     * negative infinity (@-1/0@),
--
--     * not a number (@0/0@).
--
--   A finite non-zero value @F m e@ denotes @m * 2^e@.  The magnitude of
--   @m@ has exactly 'floatDigits' bits, i.e. its integer base-2 logarithm
--   is @floatDigits - 1@.  Zero has its own constructor.  Arithmetic
--   results are re-normalised by discarding low-order bits (shifting
--   right) rather than by IEEE round-to-nearest.
--
--   The 'Floating' instance so far only implements algorithms for:
--
--     * 'pi',
--
--     * 'sqrt',
--
--     * 'exp',
--
--     * 'log'.
--
--   These 'Floating' methods transit via 'Double' and so have limited
--   precision:
--
--     * 'sin', 'cos', 'tan',
--
--     * 'asin', 'acos', 'atan',
--
--     * 'sinh', 'cosh', 'tanh',
--
--     * 'asinh', 'acosh', 'atanh'.
--
--   'floatRange' is arbitrarily limited to mitigate the problems that
--   occur when enormous integers might be needed during some number
--   type conversions (worst case consequence: program abort in gmp).
--   Values normalised by the internal checks underflow to zero when
--   their exponent falls below that range.  They overflow to the
--   appropriately signed infinity when it rises above it.
--
data VFloat p
  = F !Integer !Int
    -- (mantissa, exponent): the value is mantissa * 2^exponent
    -- invariant: matches decodeFloat spec
    -- if unsure, use encodeVFloat which maintains the invariant
    -- if sure, use checkVFloat which checks the invariant
    -- only construct with bare F when absolutely sure
  | FZero   -- FIXME add negative zero
  | FPosInf
  | FNegInf
  | FNaN    -- FIXME add payload
  -- NOTE: constructor order matters for the derived Data instance
  deriving (Data, Typeable)

-- | Build a normalised 'VFloat' for @m * 2^e@ at the precision given by
--   the type of the first argument.  Only the witness's type is used,
--   never its value.
--
--   A zero mantissa gives 'FZero'.  Otherwise the magnitude of @m@ is
--   shifted until it has exactly @precision@ bits, and the exponent is
--   adjusted to compensate.  Bits shifted out on the right are
--   discarded, i.e. the magnitude is truncated before the sign is put
--   back.  The result goes through 'checkVFloat', so an out-of-range
--   exponent becomes 'FZero' or a signed infinity.
encodeVFloat :: NaturalNumber p => VFloat p -> Integer -> Int -> VFloat p
encodeVFloat witness = self
    where
      -- target mantissa width in bits, taken from the witness's type
      b = fromIntegral $ precision (undefined `asTypeOf` witness)
      -- required integerLog2 of a normalised mantissa magnitude
      b' = b - 1
      self 0 !_ = FZero
      self m  e = checkVFloat "encodeFloat'" $ encodeFloat'' (m > 0) m' (e - sh) l
        where
          absm = abs m
          -- coarse alignment in a single shift.  Assuming integerLog2 is
          -- floor(log2), this leaves m' one bit wider than b; the loop
          -- below then shifts it right once more.
          m' = absm `shift` sh
          e2 = integerLog2 absm
          sh = b - e2
          l = integerLog2 m'
      -- Fine alignment: shift one bit at a time until integerLog2 of the
      -- magnitude equals b'.  s' is True for a positive input; l tracks
      -- integerLog2 m' so that it is never recomputed.
      encodeFloat'' !s' !m' !e' !l
        | m' <= 0 = failed -- FIXME
        | b' == l = F (if s' then m' else negate m') e'
        | b' <  l = {-# SCC "encodeFloat''.shiftR" #-} encodeFloat'' s' (m' `shiftR` 1) (e' + 1) (l - 1)
        | b' >  l = {-# SCC "encodeFloat''.shiftL" #-} encodeFloat'' s' (m' `shiftL` 1) (e' - 1) (l + 1)
        | otherwise = failed -- FIXME
        where
          failed = error $ "Numeric.VariablePrecision.Float.encodeVFloat: internal error (please report this bug): "
                        ++ show (b, b', l, s', m', e')


-- Empty instance: every 'DispFloat' method takes its default.
-- NOTE(review): presumably required so that 'fshowFloat' (used by the
-- 'FShow' instance) can render 'VFloat'; confirm against
-- "Text.FShow.RealFloat".
instance NaturalNumber p => DispFloat (VFloat p) where


instance NaturalNumber p => FShow (VFloat p) where

  -- Signed rendering: 'showSignedFloat' handles the minus sign (and any
  -- precedence parentheses) around the unsigned 'fshowFloat' output.
  fshowsPrec = showSignedFloat fshowFloat


instance NaturalNumber p => Show (VFloat p) where

  -- 'show' uses exactly the rendering of this module's 'FShow' instance.
  showsPrec d x = fshowsPrec d x


instance NaturalNumber p => Read (VFloat p) where

  -- Optional leading minus ('readSigned') around an unsigned decimal
  -- literal ('readFloat'); the precedence argument is discarded.
  readsPrec = const (readSigned readFloat) -- FIXME ignores precedence, NaN/Inf fail?


-- No methods are given here, so every 'HasPrecision' method takes its
-- class default.  NOTE(review): presumably those defaults read the
-- precision off the type parameter @p@; confirm where the class is defined.
instance HasPrecision VFloat


-- | Exponent bounds for 'F' values: @[-2^20, 2^20]@.  'checkVFloat'
--   saturates exponents outside this range.
minimumExponent, maximumExponent :: Int
maximumExponent = bit 20
minimumExponent = negate maximumExponent

-- | Type witnesses for pinning down types of local bindings: these
--   return 'undefined' at the argument type, the result type, or the
--   two-argument result type of the function given.  They must only be
--   used for their type (e.g. with 'asTypeOf'), never evaluated.
asTypeIn :: (a -> b) -> a
asTypeIn = const undefined

asTypeOut :: (a -> b) -> b
asTypeOut = const undefined

asTypeOut2 :: (a -> b -> c) -> c
asTypeOut2 = const undefined


instance VariablePrecision VFloat where

  -- Re-express a value at the precision of the result type.  Special
  -- values carry over unchanged.  A finite mantissa is shifted by the
  -- difference in digit counts, left to widen or right to narrow (which
  -- drops low-order bits), with the exponent compensated.  It is then
  -- re-normalised by 'encodeVFloat'.
  adjustPrecision = retarget
    where
      source = asTypeIn  retarget
      target = asTypeOut retarget
      delta  = floatDigits target - floatDigits source
      retarget (F m e) = encodeVFloat target (m `shift` delta) (e - delta)
      retarget FZero   = FZero
      retarget FPosInf = FPosInf
      retarget FNegInf = FNegInf
      retarget FNaN    = FNaN


instance Eq (VFloat p) where

  -- Structural equality of finite values and infinities.  NaN is never
  -- equal to anything, itself included, so NaN /= NaN.
  F m e   == F m' e' = m == m' && e == e'
  FZero   == FZero   = True
  FPosInf == FPosInf = True
  FNegInf == FNegInf = True
  _       == _       = False

  x /= y = not (x == y)


instance Ord (VFloat p) where

  -- Strict less-than.  Two 'F' mantissas always have the same bit length
  -- (the 'F' invariant).  So for operands of equal sign the exponent
  -- decides, and the mantissa only breaks exponent ties; for two
  -- negative operands the exponent order is reversed.  'FZero' compares
  -- by the sign of the other mantissa.  Every comparison with 'FNaN' is
  -- False.  'FPosInf' lies above, and 'FNegInf' below, every other
  -- non-NaN value.  Clause order matters.
  FZero   <  FZero   = False
  FZero   <  F x _   = 0 < x
  F a _   <  FZero   = a < 0
  F a b   <  F x y
    | a > 0 && x > 0 && b <  y = True
    | a > 0 && x > 0 && b == y = a < x
    | a > 0 && x > 0 && b >  y = False
    | a > 0 && x < 0           = False
    | a < 0 && x > 0           = True
    | a < 0 && x < 0 && b <  y = False
    | a < 0 && x < 0 && b == y = a < x
    | a < 0 && x < 0 && b >  y = True
    | otherwise = unreachable -- only a zero mantissa in F could get here
  FNaN    <  _       = False
  _       <  FNaN    = False
  FPosInf <  _       = False
  _       <  FPosInf = True
  _       <  FNegInf = False
  FNegInf <  _       = True

  -- The other comparisons are built from '<' and '=='.  Both are False
  -- whenever NaN is involved, so these are too.
  a       >  x       = x < a

  a       <= x       = a < x || a == x

  a       >= x       = a > x || a == x

  -- 'min' and 'max' propagate NaN from either argument.  The bang
  -- patterns force the other argument even when NaN is returned.
  min a@FNaN !_ = a
  min !_ x@FNaN = x
  min a x
    | a <= x    = a
    | otherwise = x

  max a@FNaN !_ = a
  max !_ x@FNaN = x
  max a x
    | a >= x    = a
    | otherwise = x

  -- 'compare' uses default implementation in Ord
  -- NOTE(review): the Report's default 'compare' yields GT whenever NaN
  -- is involved, because both '==' and '<=' are False there.


instance NaturalNumber p => Num (VFloat p) where

  -- Addition of two finite values aligns the operand with the smaller
  -- exponent by shifting its mantissa right by the exponent difference,
  -- dropping low-order bits.  It then adds and re-encodes at the larger
  -- exponent; exact cancellation gives FZero via encodeVFloat.  For
  -- special values, NaN propagates (left operand first), zero is the
  -- identity, and opposite infinities give NaN.  Clause order matters.
  -- NOTE(review): shiftR on a negative Integer rounds toward negative
  -- infinity, while encodeVFloat truncates magnitudes.  So rounding of a
  -- negative result may differ from that of its positive counterpart.
  f@(F a b) + F x y
    | b >  y = encodeVFloat f (a + (x `shiftR` (b - y))) b
    | b == y = encodeVFloat f (a + x) b
    | b <  y = encodeVFloat f ((a `shiftR` (y - b)) + x) y
    | otherwise = unreachable
  a@FNaN  + _       = a
  _       + x@FNaN  = x
  FZero   + x       = x
  a       + FZero   = a
  FPosInf + FNegInf = FNaN
  FNegInf + FPosInf = FNaN
  FPosInf + _       = FPosInf
  _       + FPosInf = FPosInf
  FNegInf + _       = FNegInf
  _       + FNegInf = FNegInf

  -- Subtraction mirrors addition; equal infinities give NaN.
  f@(F a b) - F x y
    | b >  y = encodeVFloat f (a - (x `shiftR` (b - y))) b
    | b == y = encodeVFloat f (a - x) b
    | b <  y = encodeVFloat f ((a `shiftR` (y - b)) - x) y
    | otherwise = unreachable
  a@FNaN  - _       = a
  _       - x@FNaN  = x
  FZero   - x       = negate x
  a       - FZero   = a
  FPosInf - FPosInf = FNaN
  FNegInf - FNegInf = FNaN
  FPosInf - _       = FPosInf
  _       - FPosInf = FNegInf
  FNegInf - _       = FNegInf
  _       - FNegInf = FPosInf

  -- Exact: only the mantissa's sign changes.  There is no negative zero.
  negate (F a b) = checkVFloat "negate" $ F (negate a) b
  negate FZero   = FZero
  negate FPosInf = FNegInf
  negate FNegInf = FPosInf
  negate a@FNaN  = a

  -- NaN and zero are returned unchanged, since no comparison with 0
  -- holds for NaN.
  abs !a
    | a < 0     = negate a
    | otherwise = a

  signum !a
    | a < 0     = -1
    | a > 0     =  1
    | otherwise =  a

  -- Multiplication of two finite values forms the exact mantissa
  -- product.  It shifts that right by k - 1, keeping roughly k bits,
  -- and compensates in the exponent; encodeVFloat then normalises.
  -- Zero times an infinity is NaN.  An infinity times a non-zero value
  -- is an infinity whose sign comes from 'sameSign'.
  f@(F a b) * F x y   = encodeVFloat f ((a * x) `shiftR` (k - 1)) (b + y + k - 1) where k = fromIntegral $ precision f
  a@FNaN    * _       = a
  _         * x@FNaN  = x
  FZero     * FPosInf = FNaN
  FZero     * FNegInf = FNaN
  FZero     * _       = FZero
  FPosInf   * FZero   = FNaN
  FNegInf   * FZero   = FNaN
  _         * FZero   = FZero
  a         * x
    | sameSign a x    = FPosInf
    | otherwise       = FNegInf

  -- Exact up to truncation to the target precision (see encodeVFloat).
  fromInteger !i = encodeFloat i 0


instance NaturalNumber p => Real (VFloat p) where

  -- Finite values convert exactly: @F m e@ is @m * 2^e@.
  --
  -- Non-finite values map to GHC's zero-denominator rationals
  -- ('GHC.Real.infinity', its negation, and 'GHC.Real.notANumber').
  -- These are built without '%': @x % 0@ always throws "Ratio has zero
  -- denominator", so spelling them @1 % 0@ etc. could never return a
  -- value.  This module's 'fromRational' maps them back to 'FPosInf',
  -- 'FNegInf' and 'FNaN' (via division by 'FZero').
  -- NOTE(review): GHC's 'fromRational' for 'Double' is believed to map
  -- them to +/-Infinity and NaN as well; confirm for the supported base
  -- versions.  Ordinary 'Rational' arithmetic on them, which
  -- re-normalises, may still fail.
  toRational FZero = 0
  toRational (F m e)
    | e >  0 = fromInteger (m `shiftL` e)
    | e == 0 = fromInteger m
    | e <  0 = m % bit (negate e)
    | otherwise = unreachable
  toRational FPosInf = R.infinity
  toRational FNegInf = negate R.infinity
  toRational FNaN    = R.notANumber


instance NaturalNumber p => Fractional (VFloat p) where

  -- Division of two finite values multiplies by the reciprocal, which
  -- costs two roundings.  Special cases, where clause order matters:
  -- NaN propagates; infinity / infinity is NaN; anything finite divided
  -- by an infinity is FZero (there is no negative zero); x / 0 is a
  -- signed infinity for x /= 0 and NaN for 0 / 0; and an infinity
  -- divided by a non-zero finite value is an infinity whose sign comes
  -- from 'sameSign'.
  f@(F _ _) / g@(F _ _) = f * recip g
  a@FNaN    / _       = a
  _         / x@FNaN  = x
  FPosInf   / FPosInf = FNaN
  FPosInf   / FNegInf = FNaN
  FNegInf   / FPosInf = FNaN
  FNegInf   / FNegInf = FNaN
  _         / FPosInf = FZero
  _         / FNegInf = FZero
  a         / FZero
    | a > 0           = FPosInf
    | a < 0           = FNegInf
    | otherwise       = FNaN
  FZero     / _       = FZero
  a         / x
    | a `sameSign` x  = FPosInf
    | otherwise       = FNegInf

  -- recip of a finite value is 2^k `quot` m at exponent -(k + e), with
  -- k twice the precision.  That is 1 / (m * 2^e) with about precision
  -- + 1 significant bits before normalisation; 'quot' truncates toward
  -- zero.
  recip a@FNaN = a
  recip FZero   = FPosInf
  recip FPosInf = FZero
  recip FNegInf = FZero
  recip f@(F m e) = encodeVFloat f (bit k `quot` m) (negate (k + e)) where k = 2 * fromIntegral (precision f)

  -- Converts numerator and denominator separately and divides, so
  -- several roundings accumulate.  A zero denominator gives an infinity
  -- or NaN through the x / FZero cases above.
  fromRational r = fromInteger (numerator r) / fromInteger (denominator r) -- FIXME accuracy


instance NaturalNumber p => RealFrac (VFloat p) where

  -- Split a finite value into an integer part truncated toward zero and
  -- a fractional part.  Both parts share the sign of the input and the
  -- fractional part has magnitude below 1, as 'properFraction'
  -- requires.  A non-finite input yields an integer part that fails
  -- when forced, paired with the input itself.
  properFraction = self
    where
      p = fromIntegral $ precision (asTypeIn self)
      self FZero = (0, FZero)
      self me@(F m e)
        -- m * 2^e with e >= 0 is already an integer; it has to be scaled
        -- by 2^e, not taken as the bare mantissa.
        | e >= 0 = (fromInteger (m `shiftL` e), FZero)
        -- |m| < 2^p, so |m * 2^e| < 1 as soon as e <= -p.
        | e <= negate p = (0, me)
        | otherwise = (fromInteger n, f)
        where
          k = negate e
          -- Integer part truncated toward zero.  shiftR alone would round
          -- negative mantissas toward negative infinity.
          n | m >= 0    = m `shiftR` k
            | otherwise = negate (negate m `shiftR` k)
          -- Exact remainder m - n * 2^k, still scaled by 2^e.  It has
          -- fewer bits than m, so re-normalising it with encodeVFloat
          -- loses nothing, and a zero remainder becomes FZero.  (Building
          -- a raw F here can break the F invariant and abort.)
          f = encodeVFloat me (m - (n `shiftL` k)) e
      self f = (error $ "Numeric.VariablePrecision.Float.properFraction: not finite: " ++ show f, f)

  -- 'truncate' uses default implementation in RealFrac

  -- 'floor' uses default implementation in RealFrac

  -- 'ceiling' uses default implementation in RealFrac

  -- 'round' uses default implementation in RealFrac


instance NaturalNumber p => RealFloat (VFloat p) where

  floatRadix _ = 2

  -- The mantissa width comes from the precision type parameter alone.
  floatDigits = self
    where
      prec = fromIntegral $ precision (asTypeIn self)
      self = const prec

  floatRange = const (minimumExponent, maximumExponent) -- FIXME arbitrary
  -- this floatRange is somewhat arbitrary, but toInteger gives integers
  -- with up to around (precision + maxExponent) bits, the value here
  -- gives rise to potentially more than 300k decimal digits...

  isNaN FNaN = True
  isNaN _    = False

  isInfinite FPosInf = True
  isInfinite FNegInf = True
  isInfinite _       = False

  -- There are no denormalised values and no negative zero.
  isDenormalized _ = False

  isNegativeZero _ = False

  -- 'isIEEE' reports whether the type follows IEEE 754.  This software
  -- format does not: it has no negative zero, no denormals, no NaN
  -- payload, a non-IEEE exponent range, and truncating normalisation.
  isIEEE _ = False

  decodeFloat FZero   = (0, 0)
  decodeFloat (F m e) = (m, e)
  decodeFloat f = error $ "Numeric.VariablePrecision.Float.decodeFloat: not finite: " ++ show f

  encodeFloat = self
    where
      self = encodeVFloat (undefined `asTypeOf` asTypeOut2 self)

  exponent = self
    where
      prec = fromIntegral $ precision (asTypeIn self)
      self FZero   = 0
      self (F _ e) = e + prec
      self f = error $ "Numeric.VariablePrecision.Float.exponent: not finite: " ++ show f

  significand = self
    where
      prec = fromIntegral $ precision (asTypeIn self)
      e = negate prec
      self (F m _) = checkVFloat "significand" $ F m e
      self f = f

  -- Exact rescaling by 2^n.  checkVFloat saturates out-of-range results
  -- to FZero or a signed infinity.  n is clamped first so that e + n
  -- cannot wrap around in Int (e.g. 'scaleFloat maxBound 1' used to
  -- underflow to zero).  Since valid exponents span only
  -- [minimumExponent, maximumExponent], any shift beyond 'lim'
  -- saturates to the same result as the clamped one.
  scaleFloat n (F m e) = checkVFloat "scaleFloat" $ F m (e + clamp n)
    where
      lim = 2 * (maximumExponent - minimumExponent)
      clamp = max (negate lim) . min lim
  scaleFloat _ f = f

  -- 'atan2' uses default implementation in RealFloat


-- | Shift an 'Integer' left for a positive count and right (arithmetic)
--   for a negative one.  Strict in both arguments.
shift :: Integer -> Int -> Integer
shift n k = n `seq` case compare k 0 of
  GT -> n `shiftL` k
  EQ -> n
  LT -> n `shiftR` negate k


instance NaturalNumber p => Floating (VFloat p) where -- FIXME

  -- Implemented with the imported generic* algorithms rather than by
  -- going through 'Double'.
  pi   = genericPi 2
  exp  = genericExp 2
  sqrt = genericSqrt 2
  log  = naturalLog
    where
      ln2        = genericLog2 2
      naturalLog = genericLog' 2 ln2

  -- '(**)' and 'logBase' use the default implementations in Floating,
  -- which are expressed through 'exp' and 'log' above.

  -- The remaining methods round-trip through 'Double' and so are only
  -- as precise as 'Double'.
  sin   = viaDouble sin   -- FIXME
  cos   = viaDouble cos   -- FIXME
  tan   = viaDouble tan   -- FIXME
  asin  = viaDouble asin  -- FIXME
  acos  = viaDouble acos  -- FIXME
  atan  = viaDouble atan  -- FIXME
  sinh  = viaDouble sinh  -- FIXME
  cosh  = viaDouble cosh  -- FIXME
  tanh  = viaDouble tanh  -- FIXME
  asinh = viaDouble asinh -- FIXME
  acosh = viaDouble acosh -- FIXME
  atanh = viaDouble atanh -- FIXME


-- | Validate a freshly built 'VFloat' and saturate its exponent.
--
--   Despite the name, calling this is vital for correct behaviour.  For
--   an 'F' value it aborts with an internal error unless the mantissa
--   magnitude has integer base-2 logarithm @precision - 1@.  It maps an
--   exponent below 'minimumExponent' to 'FZero' (underflow) and one above
--   'maximumExponent' to the signed infinity (overflow).  Any other
--   constructor passes through unchanged.  The 'String' names the caller
--   and appears only in the error message.
checkVFloat :: NaturalNumber p => String -> VFloat p -> VFloat p
checkVFloat = check
  where
    prec  = fromIntegral $ precision (asTypeOut2 check)
    prec' = prec - 1
    elo   = minimumExponent
    ehi   = maximumExponent
    check caller x@(F m e)
      | not mok   = error $ "Numeric.VariablePrecision.Float.checkVFloat." ++ caller ++ ": internal error (please report this bug): " ++ show ((m, am, lm, prec, prec', mok), (elo, e, ehi, eok))
      | eok       = x
      | e < elo   = FZero
      | otherwise = case compare m 0 of  -- overflow: saturate by sign
          GT -> FPosInf
          LT -> FNegInf
          EQ -> unreachable
      where
        am  = abs m
        lm  = integerLog2 am
        mok = lm == prec'
        eok = elo <= e && e <= ehi
    check _ x = x


-- | A selection of norms.
--
--   Each method measures a precision-indexed value of type @t p@ and
--   returns the result as a 'VFloat' at the same precision @p@.  For
--   'VFloat' itself, 'norm1', 'norm2' and 'normInfinity' are all 'abs',
--   and 'norm2Squared' is the square.
class HasPrecision t => Normed t where
  -- | The 1-norm (presumably the sum of component magnitudes for
  --   composite types).
  norm1        :: NaturalNumber p => t p -> VFloat p
  -- | The 2-norm (presumably the Euclidean length).
  norm2        :: NaturalNumber p => t p -> VFloat p
  -- | The square of the 2-norm.
  norm2Squared :: NaturalNumber p => t p -> VFloat p
  -- | The infinity norm (presumably the largest component magnitude);
  --   this is the norm used by 'effectivePrecision'.
  normInfinity :: NaturalNumber p => t p -> VFloat p


-- For a scalar every norm is its magnitude; the squared 2-norm is the
-- plain square.
instance Normed VFloat where
  norm1          = abs
  norm2          = norm1
  normInfinity   = norm1
  norm2Squared x = x * x


-- | A measure of meaningful precision in the difference of two
--   finite non-zero values.
--
--   Values of very different magnitude have little meaningful
--   difference, because @a + b `approxEq` a@ when @|a| >> |b|@.
--
--   Very close values have little meaningful difference,
--   because @a + (a - b) `approxEq` a@ as @|a| >> |a - b|@.
--
--   'effectivePrecisionWith' attempts to quantify this.  Let @x@, @y@
--   and @z@ be the 'exponent's of the norms of the two inputs and of
--   their difference, and @p@ the 'floatDigits' of the norm type.  The
--   result is @p - max d (e - d)@, where @d = max x y - z@ and
--   @e = min |x - y| p@.  It is @0@ whenever any of the three norms is
--   not strictly positive and finite.
--
effectivePrecisionWith :: (Num t, RealFloat r) => (t -> r) {- ^ norm -} -> t -> t -> Int
effectivePrecisionWith norm u v
  | all usable [nu, nv, nd] = digits - max spread (gap - spread)
  | otherwise               = 0
  where
    -- only finite, strictly positive norms have meaningful exponents
    usable r = r > 0 && not (isInfinite r)
    nu = norm u
    nv = norm v
    nd = norm (u - v)
    digits = floatDigits nu
    eu = exponent nu
    ev = exponent nv
    spread = max eu ev - exponent nd
    gap = min (abs (eu - ev)) digits


-- | 'effectivePrecisionWith' specialised to the 'normInfinity' norm.
effectivePrecision :: (NaturalNumber p, HasPrecision t, Normed t, Num (t p)) => t p -> t p -> Int
effectivePrecision x y = effectivePrecisionWith normInfinity x y
infix 6 `effectivePrecision`


-- | An infix alias for 'effectivePrecision', with the same fixity.
(-@?) :: (NaturalNumber p, HasPrecision t, Normed t, Num (t p)) => t p -> t p -> Int
x -@? y = effectivePrecision x y
infix 6 -@?


-- | Placeholder for branches that this module's invariants rule out;
--   evaluating it always signals an internal bug.
unreachable :: a
unreachable = error message
  where
    message = "Numeric.VariablePrecision.Float: internal error (please report this bug): unreachable code was reached"


-- | A concrete format suitable for storage or wire transmission.
--
--   Every constructor records the precision ('dPrecision') that a
--   'VFloat' carries only in its type.  'DFloat' also holds the mantissa
--   ('dMantissa') and exponent ('dExponent') of a finite non-zero value.
--   The other constructors stand for zero, the two infinities and NaN.
--   'dMantissa' and 'dExponent' are partial selectors: they fail on
--   every constructor except 'DFloat'.
--
--   The derived 'Eq' and 'Ord' instances are structural (constructor
--   order, then fields), not numeric.
data DFloat
  = DFloat            { dPrecision :: !Word, dMantissa :: !Integer, dExponent :: !Int }
  | DZero             { dPrecision :: !Word }
  | DPositiveInfinity { dPrecision :: !Word }
  | DNegativeInfinity { dPrecision :: !Word }
  | DNotANumber       { dPrecision :: !Word }
  deriving (Eq, Ord, Read, Show, Data, Typeable)

-- | Freeze a 'VFloat', recording the precision taken from its type.
toDFloat :: NaturalNumber p => VFloat p -> DFloat
toDFloat v = case v of
    F m e   -> DFloat bits m e
    FZero   -> DZero bits
    FPosInf -> DPositiveInfinity bits
    FNegInf -> DNegativeInfinity bits
    FNaN    -> DNotANumber bits
  where
    bits = precision v

-- | Thaw a 'DFloat' at the precision demanded by the result type.  A
--   finite value is re-normalised with 'encodeVFloat'.  Returns
--   'Nothing' when the stored precision differs from the requested one.
fromDFloat :: NaturalNumber p => DFloat -> Maybe (VFloat p)
fromDFloat d
  | dPrecision d /= precision thawed = Nothing
  | otherwise                        = Just thawed
  where
    thawed = case d of
      DFloat _ m e        -> encodeVFloat undefined m e
      DZero _             -> FZero
      DPositiveInfinity _ -> FPosInf
      DNegativeInfinity _ -> FNegInf
      DNotANumber _       -> FNaN

-- | Thaw a 'DFloat' to its natural precision and hand it to a
--   continuation that works at any precision.  A finite value is
--   re-normalised with 'encodeVFloat'; special values are thawed
--   directly.
withDFloat :: DFloat -> (forall p . NaturalNumber p => VFloat p -> r) -> r
withDFloat d k = case d of
  DFloat bits m e -> reifyPrecision bits $ \prec -> k (encodeVFloat undefined m e `atPrecision` prec)
  _               -> unsafeWithDFloat d k

-- | Thaw a 'DFloat' without guaranteeing a well-formed 'VFloat' value.
--   Mantissa and exponent go into 'F' as-is, with no 'encodeVFloat'
--   normalisation and no 'checkVFloat' validation.
--   Possibly slightly faster.
unsafeWithDFloat :: DFloat -> (forall p . NaturalNumber p => VFloat p -> r) -> r
unsafeWithDFloat d k = d `seq` reifyPrecision (dPrecision d) $ \prec -> k (thaw d `atPrecision` prec)
  where
    thaw (DFloat _ m e)        = F m e
    thaw (DZero _)             = FZero
    thaw (DPositiveInfinity _) = FPosInf
    thaw (DNegativeInfinity _) = FNegInf
    thaw (DNotANumber _)       = FNaN