{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

module NumHask.Data.LogField
  ( -- * @LogField@
    LogField (),
    logField,
    fromLogField,

    -- ** Isomorphism to log-domain
    logToLogField,
    logFromLogField,

    -- ** Additional operations
    accurateSum,
    accurateProduct,
    pow,
  )
where

import Data.Data (Data)
import qualified Data.Foldable as F
import GHC.Generics (Generic, Generic1)
import NumHask.Algebra.Abstract.Additive
import NumHask.Algebra.Abstract.Field
import NumHask.Algebra.Abstract.Lattice
import NumHask.Algebra.Abstract.Multiplicative
import NumHask.Algebra.Abstract.Ring
import NumHask.Analysis.Metric
import NumHask.Data.Integral
import NumHask.Data.Rational
import Prelude hiding (Num (..), exp, log, negate)

-- LogField is adapted from LogFloat
----------------------------------------------------------------
--                                                  ~ 2015.08.06

-- |
-- Module      :  Data.Number.LogFloat
-- Copyright   :  Copyright (c) 2007--2015 wren gayle romano
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  stable
-- Portability :  portable (with CPP, FFI)
-- Link        :  https://hackage.haskell.org/package/logfloat
----------------------------------------------------------------
----------------------------------------------------------------
--

-- | A @LogField@ is just a 'Field' with a special interpretation.
-- The 'logField' function is presented instead of the constructor,
-- in order to ensure semantic conversion. At present the 'Show'
-- instance will convert back to the normal-domain, and hence will
-- underflow at that point. This behavior may change in the future.
--
-- Because 'logField' performs the semantic conversion, we can use
-- operators which say what we *mean* rather than saying what we're
-- actually doing to the underlying representation. That is,
-- equivalences like the following are true[1] thanks to type-class
-- overloading:
--
-- > logField (p + q) == logField p + logField q
-- > logField (p * q) == logField p * logField q
--
--
-- Performing operations in the log-domain is cheap, prevents
-- underflow, and is otherwise very nice for dealing with minuscule
-- probabilities. However, crossing into and out of the log-domain
-- is expensive and should be avoided as much as possible. In
-- particular, if you're doing a series of multiplications as in
-- @lp * logField q * logField r@ it's faster to do @lp * logField
-- (q * r)@ if you're reasonably sure the normal-domain multiplication
-- won't underflow; because that way you enter the log-domain only
-- once, instead of twice. Also note that, for precision, if you're
-- doing more than a few multiplications in the log-domain, you
-- should use 'accurateProduct' rather than using '(*)' repeatedly.
--
-- Even more particularly, you should /avoid addition/ whenever
-- possible. Addition is provided because sometimes we need it, and
-- the proper implementation is not immediately apparent. However,
-- between two @LogField@s addition requires crossing the exp\/log
-- boundary twice; with a @LogField@ and a 'Double' it's three
-- times, since the regular number needs to enter the log-domain
-- first. This makes addition incredibly slow. Again, if you can
-- parenthesize to do normal-domain operations first, do it!
--
-- [1] That is, true up-to underflow and floating point fuzziness.
-- Which is, of course, the whole point of this module.
newtype LogField a
  = -- | Wraps the log-domain representation of a value. The constructor
    -- is not exported; use 'logField' or 'logToLogField' instead.
    LogField a
  deriving (Eq, Ord, Read, Data, Generic, Generic1, Functor, Foldable, Traversable)

----------------------------------------------------------------
-- To show it, we want to show the normal-domain value rather than
-- the log-domain value. Also, if someone managed to break our
-- invariants (e.g. by passing in a negative and no one has pulled on
-- the thunk yet) then we want to crash before printing the
-- constructor, rather than after.  N.B. This means the show will
-- underflow\/overflow in the same places as normal doubles since
-- we underflow at the @exp@. Perhaps this means we should show the
-- log-domain value instead.
instance (ExpField a, Show a) => Show (LogField a) where
  -- Renders the normal-domain value (the 'exp' of the stored one).
  -- The value is forced with 'seq' before any output is produced.
  showsPrec prec (LogField logValue) =
    normal `seq` showParen (prec > 9) rendered
    where
      normal = exp logValue
      rendered = showString "LogField " . showsPrec 11 normal

----------------------------------------------------------------

-- | Constructor which does semantic conversion from normal-domain
-- to log-domain by taking the 'log' of its argument. No validation
-- is performed here: whatever 'log' returns for the input is what
-- gets stored. The following equivalence holds:
--
-- > logField p == logToLogField (log p)
logField :: (ExpField a) => a -> LogField a
{-# INLINE [0] logField #-}
logField = logToLogField . log

-- | Wrap a value that is /already/ in the log-domain; no conversion
-- is applied.
logToLogField :: a -> LogField a
logToLogField logValue = LogField logValue

-- | Convert a log-domain value back into the normal-domain by
-- applying 'exp' to the stored representation. Beware of
-- overflow\/underflow. The following equivalence holds (without
-- qualification):
--
-- > fromLogField == exp . logFromLogField
fromLogField :: ExpField a => LogField a -> a
{-# INLINE [0] fromLogField #-}
fromLogField lf = exp (logFromLogField lf)

-- | Unwrap the stored log-domain value as-is, without conversion.
logFromLogField :: LogField a -> a
logFromLogField lf = case lf of
  LogField logValue -> logValue

-- These are our module-specific versions of "log\/exp" and "exp\/log";
-- They do the same things but also have a @LogField@ in between
-- the logarithm and exponentiation. In order to ensure these rules
-- fire, we have to delay the inlining on two of the four
-- con-\/destructors ('logField' and 'fromLogField' are @INLINE [0]@).
--
-- N.B. the second rule must match the /smart constructor/ 'logField',
-- not the raw @LogField@ constructor: @fromLogField (LogField x)@ is
-- @exp x@, so rewriting it to @x@ would change the result.
{-# RULES
"log/fromLogField" forall x.
  log (fromLogField x) =
    logFromLogField x
"fromLogField/logField" forall x. fromLogField (logField x) = x
  #-}

-- Local @log (1 + x)@ helper, computed directly with 'log'; no
-- dedicated small-argument algorithm is used.
log1p :: ExpField a => a -> a
{-# INLINE [0] log1p #-}
log1p x = log onePlusX
  where
    onePlusX = one + x

-- Local @exp x - 1@ helper, computed directly with 'exp'; no
-- dedicated small-argument algorithm is used.
expm1 :: (ExpField a) => a -> a
{-# INLINE [0] expm1 #-}
expm1 x =
  let expX = exp x
   in expX - one

-- Rewrite rules cancelling directly nested 'expm1' \/ 'log1p'
-- applications. Both helpers carry @INLINE [0]@ pragmas, so their
-- inlining is delayed and the rules get a chance to match first.
{-# RULES
"expm1/log1p" forall x. expm1 (log1p x) = x
"log1p/expm1" forall x. log1p (expm1 x) = x
  #-}

instance
  (ExpField a, LowerBoundedField a, Ord a) =>
  Additive (LogField a)
  where
  -- Normal-domain zero is stored as negative infinity.
  zero = LogField negInfinity

  -- Log-domain addition. Zero operands are returned untouched;
  -- otherwise the larger stored value is factored out so that 'exp'
  -- is only ever applied to a non-positive difference.
  p@(LogField lp) + q@(LogField lq)
    | pIsZero && qIsZero = zero
    | pIsZero = q
    | qIsZero = p
    | p >= q = lp `plusSmaller` lq
    | otherwise = lq `plusSmaller` lp
    where
      pIsZero = p == zero
      qIsZero = q == zero
      plusSmaller big small = LogField (big + log1p (exp (small - big)))

instance
  (ExpField a, Ord a, LowerBoundedField a, UpperBoundedField a) =>
  Subtractive (LogField a)
  where
  -- Only zero has a negation here. Every other input yields 'nan'
  -- taken at type @LogField a@; what that value is depends on the
  -- 'UpperBoundedField' instance for @LogField a@.
  negate x = if x == zero then zero else nan

instance
  (LowerBoundedField a, Eq a) =>
  Multiplicative (LogField a)
  where
  -- Normal-domain one is stored as (log-domain) zero.
  one = LogField zero

  -- Multiplication is addition of the stored logs. If either operand
  -- is stored as negative infinity, the result is negative infinity
  -- regardless of the other operand.
  LogField x * LogField y
    | eitherIsZero = LogField negInfinity
    | otherwise = LogField (x + y)
    where
      eitherIsZero = x == negInfinity || y == negInfinity

instance (LowerBoundedField a, Eq a) => Divisive (LogField a) where
  -- The reciprocal negates the stored log-domain value.
  recip lf = LogField (negate (logFromLogField lf))

-- No methods are defined in this instance body.
instance (Ord a, LowerBoundedField a, ExpField a) => Distributive (LogField a)

instance
  (Field (LogField a), ExpField a, LowerBoundedField a, Ord a) =>
  ExpField (LogField a)
  where
  -- 'exp' and 'log' are applied to the stored log-domain value.
  exp = fmap exp
  log = fmap log

  -- The exponent is converted back to the normal-domain with 'exp'
  -- and handed to 'pow'.
  base ** expo = pow base (exp (logFromLogField expo))

instance (FromIntegral a b, ExpField a) => FromIntegral (LogField a) b where
  -- Convert into @a@ first, then move into the log-domain.
  fromIntegral_ n = logField (fromIntegral_ n)

instance (ToIntegral a b, ExpField a) => ToIntegral (LogField a) b where
  -- Leave the log-domain first, then convert the normal-domain value.
  toIntegral lf = toIntegral (fromLogField lf)

instance (FromRatio a b, ExpField a) => FromRatio (LogField a) b where
  -- Convert into @a@ first, then move into the log-domain.
  fromRatio r = logField (fromRatio r)

instance (ToRatio a b, ExpField a) => ToRatio (LogField a) b where
  -- Leave the log-domain first, then convert the normal-domain value.
  toRatio lf = toRatio (fromLogField lf)

instance (Ord a) => JoinSemiLattice (LogField a) where
  -- The join is 'min' under the derived 'Ord' instance.
  p \/ q = min p q

instance (Ord a) => MeetSemiLattice (LogField a) where
  -- The meet is 'max' under the derived 'Ord' instance.
  p /\ q = max p q

instance
  (Epsilon a, ExpField a, LowerBoundedField a, UpperBoundedField a, Ord a) =>
  Epsilon (LogField a)
  where
  -- The underlying type's epsilon, moved into the log-domain.
  epsilon = logField epsilon

  -- Both predicates are decided in the normal-domain: the stored
  -- values are passed through 'exp' before delegating to @a@.
  nearZero lf = nearZero (exp (logFromLogField lf))
  aboutEqual p q = aboutEqual (normal p) (normal q)
    where
      normal = exp . logFromLogField

-- No methods are defined in this instance body.
instance
  (Ord a, ExpField a, LowerBoundedField a, UpperBoundedField a) =>
  Field (LogField a)

-- No methods are defined here, so every class member comes from the
-- class defaults.
instance (Ord a, ExpField a, LowerBoundedField a, UpperBoundedField a) => LowerBoundedField (LogField a)

-- No methods are defined in this instance body.
instance (Ord a, ExpField a, LowerBoundedField a) => IntegralDomain (LogField a)

-- No methods are defined here, so every class member comes from the
-- class defaults.
instance (Ord a, ExpField a, LowerBoundedField a, UpperBoundedField a) => UpperBoundedField (LogField a)

instance
  (Ord a, LowerBoundedField a, UpperBoundedField a, ExpField a) =>
  Signed (LogField a)
  where
  -- A log-domain value is never negative, so the sign is either
  -- 'zero' (for zero itself) or 'one'. The zero test compares against
  -- the 'Additive' 'zero' of this type, as the other instances do,
  -- rather than against a class-default @negInfinity@ at type
  -- @LogField a@, whose value is not defined in this module.
  sign a
    | a == zero = zero
    | otherwise = one

  -- Nothing to strip: every value already counts as non-negative.
  abs = id

----------------------------------------------------------------

-- | /O(1)/. Raise a log-domain value to a normal-domain exponent by
-- scaling its stored logarithm; that is, the following equivalence
-- holds (modulo underflow and all that):
--
-- > logField (p ** m) == logField p `pow` m
--
-- A 'zero' base is special-cased: it yields 'one' when the exponent
-- is also 'zero', and stays 'zero' for every other exponent.
--
-- /Since: 0.13/
pow :: (ExpField a, LowerBoundedField a, Ord a) => LogField a -> a -> LogField a
{-# INLINE pow #-}

infixr 8 `pow`

pow base@(LogField logBase) m
  | base /= zero = LogField (m * logBase)
  | m == zero = LogField zero
  | otherwise = base

-- Some good test cases:
-- for @logsumexp == log . accurateSum . map exp@:
--     logsumexp[0,1,0] should be about 1.55
-- for correctness of avoiding underflow:
--     logsumexp[1000,1001,1000]   ~~ 1001.55 ==  1000 + 1.55
--     logsumexp[-1000,-999,-1000] ~~ -998.45 == -1000 + 1.55
--

-- | /O(n)/. Compute the sum of a finite collection of 'LogField's,
-- being careful to avoid underflow issues. That is, the following
-- equivalence holds (modulo underflow and all that):
--
-- > logField (sum ps) == accurateSum (fmap logField ps)
--
-- The largest stored value is factored out before exponentiating, so
-- every 'exp' argument is non-positive. An empty collection, or one
-- whose largest element is zero (stored as @log zero@), sums to zero;
-- these cases are answered up front because the general formula
-- would otherwise hit 'maximum' on an empty structure, or subtract
-- two negative infinities. A positive-infinite maximum is not
-- special-cased.
--
-- /N.B./, this function requires two passes over the input. Thus,
-- it is not amenable to list fusion, and hence will use a lot of
-- memory when summing long lists.
{-# INLINE accurateSum #-}
accurateSum :: (ExpField a, Foldable f, Ord a) => f (LogField a) -> LogField a
accurateSum xs
  | F.null xs = LogField logZero
  | theMax == logZero = LogField theMax
  | otherwise = LogField (theMax + log theSum)
  where
    -- log-domain representation of normal-domain zero
    logZero = log zero
    -- only demanded once the collection is known to be non-empty
    LogField theMax = maximum xs
    -- compute @\sum_{x \in xs} \exp(x - theMax)@
    theSum = F.foldl' (\acc (LogField x) -> acc + exp (x - theMax)) zero xs

-- | /O(n)/. Compute the product of a finite collection of
-- 'LogField's, being careful to avoid numerical error due to loss
-- of precision. That is, the following equivalence holds (modulo
-- underflow and all that):
--
-- > logField (product ps) == accurateProduct (fmap logField ps)
--
-- The stored logs are added up with a compensated (Kahan-style)
-- summation, folding from the right. Infinite stored values receive
-- no special treatment.
{-# INLINE accurateProduct #-}
accurateProduct :: (ExpField a, Foldable f) => f (LogField a) -> LogField a
accurateProduct = LogField . fst . F.foldr compensatedAdd (zero, zero)
  where
    -- The accumulator pairs the running total with the rounding
    -- error captured on the previous step.
    compensatedAdd (LogField x) (total, lost) = (total', lost')
      where
        corrected = x - lost
        total' = total + corrected
        lost' = (total' - total) - corrected