{-# LANGUAGE DeriveGeneric, DeriveDataTypeable, DeriveFunctor, GeneralizedNewtypeDeriving, DeriveFoldable, DeriveTraversable, GADTs #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, UndecidableInstances, MultiParamTypeClasses #-}
module NumHask.Data.LogField
(
-- * @LogField@
LogField()
-- ** Isomorphism to normal-domain
, logField
, fromLogField
-- ** Isomorphism to log-domain
, logToLogField
, logFromLogField
-- ** Additional operations
, accurateSum, accurateProduct
, pow
)where
import GHC.Generics ( Generic
, Generic1
)
import Data.Data ( Data )
import NumHask.Algebra.Additive
import NumHask.Algebra.Multiplicative
import NumHask.Algebra.Distribution
import NumHask.Algebra.Field
import NumHask.Algebra.Integral
import NumHask.Algebra.Rational
import NumHask.Algebra.Metric
import Prelude hiding ( Num(..)
, negate
, sin
, cos
, sqrt
, (/)
, atan
, pi
, exp
, log
, recip
, (**)
, toInteger
)
import qualified Data.Foldable as F
-- LogField is adapted from LogFloat
----------------------------------------------------------------
-- ~ 2015.08.06
-- |
-- Module : Data.Number.LogFloat
-- Copyright : Copyright (c) 2007--2015 wren gayle romano
-- License : BSD3
-- Maintainer : wren@community.haskell.org
-- Stability : stable
-- Portability : portable (with CPP, FFI)
-- Link : https://hackage.haskell.org/package/logfloat
----------------------------------------------------------------
----------------------------------------------------------------
--
-- | A @LogField@ is just a 'Field' value with a special interpretation:
-- it stores the /logarithm/ of the number it represents.
-- The 'logField' function is presented instead of the constructor
-- (the constructor is not exported), in order to ensure semantic
-- conversion. At present the 'Show' instance will convert back to the
-- normal-domain, and hence will underflow at that point. This
-- behavior may change in the future.
--
-- Because 'logField' performs the semantic conversion, we can use
-- operators which say what we /mean/ rather than saying what we're
-- actually doing to the underlying representation. That is,
-- equivalences like the following are true[1] thanks to type-class
-- overloading:
--
-- > logField (p + q) == logField p + logField q
-- > logField (p * q) == logField p * logField q
--
-- Performing operations in the log-domain is cheap, prevents
-- underflow, and is otherwise very nice for dealing with minuscule
-- probabilities. However, crossing into and out of the log-domain
-- is expensive and should be avoided as much as possible. In
-- particular, if you're doing a series of multiplications as in
-- @lp * logField q * logField r@ it's faster to do
-- @lp * logField (q * r)@ if you're reasonably sure the normal-domain
-- multiplication won't underflow, because that way you enter the
-- log-domain only once instead of twice. Also note that, for
-- precision, if you're doing more than a few multiplications in the
-- log-domain, you should use 'accurateProduct' rather than using
-- '(*)' repeatedly.
--
-- Even more particularly, you should /avoid addition/ whenever
-- possible. Addition is provided because sometimes we need it, and
-- the proper implementation is not immediately apparent. However,
-- between two @LogField@s addition requires crossing the exp\/log
-- boundary twice; starting from a normal-domain value it's three
-- times, since the regular number needs to enter the log-domain
-- (via 'logField') first. This makes addition incredibly slow. Again,
-- if you can parenthesize to do normal-domain operations first, do it!
--
-- [1] That is, true up-to underflow and floating point fuzziness.
-- Which is, of course, the whole point of this module.
--
-- NOTE: the derived 'Read' instance parses @LogField x@ as the raw
-- log-domain value @x@ (bypassing 'logField'), whereas 'Show' prints
-- the normal-domain value.
newtype LogField a = LogField a
  deriving (Eq, Ord, Read, Data, Generic, Generic1, Functor, Foldable, Traversable)
----------------------------------------------------------------
-- 'show' prints the normal-domain value (the stored logarithm passed
-- through 'exp'), not the stored log-domain value. That value is forced
-- with 'seq' before any output is produced, so a failure while computing
-- it surfaces before the @LogField@ prefix is emitted.
-- N.B. this means the output underflows\/overflows at the same places as
-- the underlying normal-domain type, since it goes through 'exp'.
-- Perhaps the log-domain value should be shown instead.
--
-- NOTE(review): the derived 'Read' instance parses @LogField x@ as the
-- /log-domain/ value @x@, so @read . show@ does not round-trip. Also,
-- derived instances parenthesise only when the precedence exceeds 10;
-- this one already does so at precedence 10.
instance (ExpField a, Show a) => Show (LogField a) where
  showsPrec d (LogField logValue) = normal `seq` showParen (d >= 10) rendered
    where
      normal = exp logValue
      rendered = showString "LogField " . showsPrec 11 normal
----------------------------------------------------------------
-- | Constructor which does semantic conversion from normal-domain
-- to log-domain: @logField p@ stores @log p@.
--
-- No input validation is performed. The argument is passed straight to
-- 'log', so a negative or NaN argument yields whatever 'log' returns for
-- it (NaN for IEEE types such as 'Double') rather than an error.
-- For non-negative @p@ the following equivalence holds:
--
-- > logField p == logToLogField (log p)
--
-- Inlining is delayed to phase 0 so that rewrite rules matching on this
-- conversion get a chance to fire first.
logField :: (ExpField a) => a -> LogField a
{-# INLINE [0] logField #-}
logField = LogField . log
-- TODO: the guards that used to reject negative and NaN inputs were
-- removed. Decide whether to reinstate them (which would need an 'Ord'
-- constraint) or keep this unchecked version.
-- | Wrap a value that is already a logarithm, without converting it:
-- @logToLogField l@ represents the normal-domain value @exp l@.
logToLogField :: a -> LogField a
logToLogField l = LogField l
-- | Convert back to the normal domain by exponentiating the stored
-- logarithm. Beware of overflow\/underflow. The following equivalence
-- holds (without qualification):
--
-- > fromLogField == exp . logFromLogField
--
-- Inlining is delayed to phase 0 so that rewrite rules matching on
-- 'fromLogField' get a chance to fire first.
fromLogField :: ExpField a => LogField a -> a
{-# INLINE [0] fromLogField #-}
fromLogField lf = case lf of
  LogField logValue -> exp logValue
-- | Unwrap the stored log-domain value itself, without converting it.
logFromLogField :: LogField a -> a
logFromLogField lf = case lf of
  LogField logValue -> logValue
-- These are our module-specific versions of "log\/exp" and "exp\/log".
-- They do the same things but also have a @LogField@ in between
-- the logarithm and exponentiation. In order to ensure these rules
-- fire, inlining is delayed on two of the four conversions ('logField'
-- and 'fromLogField' carry @INLINE [0]@).
--
-- N.B. the rules must be stated in terms of the semantic conversion
-- 'logField', not the raw 'LogField' constructor. @fromLogField
-- (LogField x)@ is @exp x@, so a rule rewriting it to @x@ would silently
-- change results wherever it fired.
{-# RULES
-- Out of log-domain and back in
"log/fromLogField" forall x. log (fromLogField x) = logFromLogField x
"logField/fromLogField" forall x. logField (fromLogField x) = x
-- Into log-domain and back out
"fromLogField/logField" forall x. fromLogField (logField x) = x
#-}
-- | @log1p x@ computes @log (one + x)@.
--
-- NOTE(review): this is the naive formula, so for tiny @x@ (where
-- @one + x@ rounds) it is no more accurate than calling 'log' directly.
log1p :: ExpField a => a -> a
{-# INLINE [0] log1p #-}
log1p x = log shifted
  where
    shifted = one + x
-- | @expm1 x@ computes @exp x - one@.
--
-- NOTE(review): this is the naive formula; for tiny @x@ the subtraction
-- cancels catastrophically, so it is no more accurate than writing it out.
expm1 :: ExpField a => a -> a
{-# INLINE [0] expm1 #-}
expm1 x = grown - one
  where
    grown = exp x
-- Cancel round trips between 'log1p' and 'expm1': mathematically both
-- @expm1 (log1p x)@ and @log1p (expm1 x)@ equal @x@. Both helpers are
-- marked @INLINE [0]@, so they stay un-inlined long enough for these
-- rules to match.
{-# RULES
-- Into log-domain and back out
"expm1/log1p" forall x. expm1 (log1p x) = x
-- Out of log-domain and back in
"log1p/expm1" forall x. log1p (expm1 x) = x
#-}
instance (ExpField a, LowerBoundedField a, Ord a) => AdditiveMagma (LogField a) where
  -- Log-domain addition, log (exp p + exp q), evaluated as
  -- @hi + log1p (exp (lo - hi))@ with @hi@ the larger logarithm. For
  -- non-NaN inputs the exponent handed to 'exp' is therefore never
  -- positive, so 'exp' cannot overflow.
  -- Zero operands (stored as 'negInfinity') are short-circuited. In
  -- particular, two zeros would otherwise give @-inf - -inf@ (NaN for
  -- IEEE types).
  plus lp@(LogField p) lq@(LogField q)
    | lp == zero && lq == zero = zero
    | lp == zero = lq
    | lq == zero = lp
    | lp >= lq = combine p q
    | otherwise = combine q p
    where
      -- combine hi lo == log (exp hi + exp lo), assuming hi >= lo
      combine hi lo = LogField (hi + log1p (exp (lo - hi)))
-- The additive identity: normal-domain @0@ has logarithm -infinity,
-- represented here by 'negInfinity'.
instance (LowerBoundedField a, ExpField a, Ord a) => AdditiveUnital (LogField a) where
  zero = LogField negInfinity
-- Law-carrying instances with no methods defined here. Log-domain
-- addition is associative and commutative up to floating-point rounding.
instance (LowerBoundedField a, ExpField a, Ord a) => AdditiveAssociative (LogField a)
instance (LowerBoundedField a,ExpField a, Ord a) => AdditiveCommutative (LogField a)
instance (LowerBoundedField a, ExpField a, Ord a) => Additive (LogField a)
instance (AdditiveMagma a, LowerBoundedField a, Eq a) => MultiplicativeMagma (LogField a) where
  -- Multiplying normal-domain values adds their logarithms. Zero (stored
  -- as 'negInfinity') is made absorbing explicitly, so that zero times an
  -- infinite value gives zero instead of @-inf + inf@ (NaN for IEEE types).
  times (LogField p) (LogField q) =
    LogField (if p == negInfinity || q == negInfinity then negInfinity else p `plus` q)
-- The multiplicative identity: normal-domain @1@ has logarithm 'zero'.
instance (AdditiveUnital a, LowerBoundedField a, Eq a) => MultiplicativeUnital (LogField a) where
  one = LogField zero
-- Multiplication is addition of the stored logarithms, so associativity
-- and commutativity follow from the underlying type's addition. No
-- methods are defined here.
instance (AdditiveAssociative a, LowerBoundedField a, Eq a) => MultiplicativeAssociative (LogField a)
instance (AdditiveCommutative a, LowerBoundedField a, Eq a) => MultiplicativeCommutative (LogField a)
instance (AdditiveInvertible a, LowerBoundedField a, Eq a) => MultiplicativeInvertible (LogField a) where
  -- log (1 / p) == negate (log p), so the reciprocal negates the stored
  -- logarithm (via the derived 'Functor' instance).
  recip = fmap negate
-- No methods are defined in the instances below; they only declare
-- further algebraic structure. Multiplication of 'LogField's is addition
-- of the stored logarithms, so each instance requires the corresponding
-- additive structure on the underlying type.
instance (AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveCommutative a
         , Additive a
         , LowerBoundedField a
         , Eq a) => Multiplicative (LogField a)
-- Cancellation in the log domain corresponds to cancellation of the
-- underlying (additive) logarithms.
instance (AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveInvertible a
         , AdditiveLeftCancellative a
         , LowerBoundedField a
         , Eq a) => MultiplicativeLeftCancellative (LogField a)
instance (AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveInvertible a
         , AdditiveRightCancellative a
         , LowerBoundedField a
         , Eq a) => MultiplicativeRightCancellative (LogField a)
instance (Multiplicative (LogField a), AdditiveInvertible a, AdditiveGroup a, LowerBoundedField a, Eq a) => MultiplicativeGroup (LogField a)
-- Multiplication distributes over log-domain addition
-- (p * (q + r) == p * q + p * r), up to floating-point rounding.
instance (LowerBoundedField a, ExpField a, Ord a, AdditiveMagma a) => Distribution (LogField a)
-- unable to provide this instance because there is no Field (LogField a) instance
-- instance (Field (LogField a), ExpField a, LowerBoundedField a, Ord a) => ExpField (LogField a) where
-- exp (LogField x) = (LogField $ exp x)
-- log (LogField x) = (LogField $ log x)
-- (**) x (LogField y) = pow x $ exp y
-- Numeric literals and conversions go through the normal domain. Values
-- enter via 'logField' and leave via 'fromLogField'.
instance (FromInteger a, ExpField a) => FromInteger (LogField a) where
  fromInteger n = logField (fromInteger n)
instance (ToInteger a, ExpField a) => ToInteger (LogField a) where
  toInteger lf = toInteger (fromLogField lf)
instance (FromRatio a, ExpField a) => FromRatio (LogField a) where
  fromRatio r = logField (fromRatio r)
instance (ToRatio a, ExpField a) => ToRatio (LogField a) where
  toRatio lf = toRatio (fromLogField lf)
-- Closeness is judged in the normal domain. The stored logarithms are
-- exponentiated before delegating to the underlying 'Epsilon' instance.
instance (Epsilon a, ExpField a, LowerBoundedField a, Ord a) => Epsilon (LogField a) where
  nearZero lf = case lf of
    LogField x -> nearZero (exp x)
  aboutEqual (LogField x) (LogField y) = exp x `aboutEqual` exp y
----------------------------------------------------------------
-- | /O(1)/. Compute powers in the log-domain. Raising @p@ to the power
-- @m@ multiplies the stored logarithm by @m@. That is, the following
-- equivalence holds (modulo underflow and all that):
--
-- > logField (p ** m) == logField p `pow` m
--
-- Special cases for a zero base (stored as -infinity):
--
-- * @zero \`pow\` zero@ is one, following the 0^0 = 1 convention;
--
-- * for @m > 0@ the result is zero;
--
-- * for @m < 0@ the result is +infinity, since @0 ** m@ is @1 \/ 0@.
--
-- Adapted from logfloat's @pow@ (available there since 0.13).
pow :: (ExpField a, LowerBoundedField a, Ord a) => LogField a -> a -> LogField a
{-# INLINE pow #-}
infixr 8 `pow`
pow x@(LogField x') m
  -- 0 ** 0 == 1; computing @m * x'@ here would give @0 * -inf@, i.e. NaN.
  | x == zero && m == zero = LogField zero
  -- Everything else just scales the logarithm. For a zero base this is
  -- @m * -inf@, which is -inf (zero) for positive @m@ and +inf for
  -- negative @m@. The old @x == zero = x@ shortcut wrongly returned zero
  -- for negative exponents.
  | otherwise = LogField (m * x')
-- Some good test cases, for
-- @logsumexp = logFromLogField . accurateSum . map logToLogField@:
--   logsumexp [0,1,0] should be about 1.55
-- and, for correctness of avoiding underflow:
--   logsumexp [1000,1001,1000] ~~ 1001.55 == 1000 + 1.55
--   logsumexp [-1000,-999,-1000] ~~ -998.45 == -1000 + 1.55
--
-- | /O(n)/. Compute the sum of a finite collection of 'LogField's,
-- being careful to avoid underflow issues. Every term is scaled by the
-- largest one before exponentiating (the log-sum-exp trick). That is,
-- the following equivalence holds (modulo underflow and all that):
--
-- > logField (sum ps) == accurateSum (map logField ps)
--
-- Edge cases:
--
-- * an empty collection sums to zero (@log zero@ in the log domain)
--   instead of crashing in 'maximum';
--
-- * if the largest term is not finite (+\/-infinity or NaN), the result
--   is that term. In particular, a sum of zeros is zero rather than NaN.
--
-- /N.B./, this function requires two passes over the input. Thus,
-- it is not amenable to list fusion, and hence will use a lot of
-- memory when summing long lists.
{-# INLINE accurateSum #-}
accurateSum :: (ExpField a, Foldable f, Ord a) => f (LogField a) -> LogField a
accurateSum xs
  | F.null xs = LogField (log zero)
  | not (isFinite theMax) = LogField theMax
  | otherwise = LogField (theMax + log theSum)
  where
    -- Lazy pattern binding: only evaluated once the empty case is excluded.
    LogField theMax = maximum xs
    -- @v - v == zero@ holds exactly for every finite IEEE value, but not
    -- for infinities or NaN (both give NaN). This detects the cases where
    -- @x - theMax@ below would be NaN (e.g. @-inf - -inf@ when every term
    -- is zero).
    isFinite v = (v - v) == zero
    -- compute @\sum_{x \in xs} \exp(x - theMax)@; its log is added back above
    theSum = F.foldl' (\acc (LogField x) -> acc + exp (x - theMax)) zero xs
-- | /O(n)/. Compute the product of a finite collection of 'LogField's,
-- being careful to avoid numerical error due to loss of precision. The
-- stored logarithms are added with Kahan (compensated) summation. That
-- is, the following equivalence holds (modulo underflow and all that):
--
-- > logField (product ps) == accurateProduct (map logField ps)
--
-- Elements are combined right-to-left (as with 'foldr'), but with a
-- strict accumulator. The fold therefore runs in constant stack space
-- instead of recursing over the whole input and building chains of
-- unevaluated sums.
--
-- NOTE(review): for IEEE types, once the running sum becomes infinite
-- (e.g. a zero factor, whose logarithm is -infinity) the compensation
-- term becomes NaN. Any element processed afterwards then turns the
-- result into NaN. Guarding against that needs an 'Eq' constraint that
-- this signature lacks.
{-# INLINE accurateProduct #-}
accurateProduct :: (ExpField a, Foldable f) => f (LogField a) -> LogField a
accurateProduct = LogField . fst . F.foldr' kahanPlus (zero, zero)
  where
    -- One Kahan step: add @x@ to the running total @t@, where @c@ carries
    -- the low-order error lost by previous additions. Both components are
    -- forced so that 'F.foldr'' (which only forces the pair) does not
    -- leave thunks behind.
    kahanPlus (LogField x) (t, c) =
      let y = x - c
          t' = t + y
          c' = (t' - t) - y
      in t' `seq` c' `seq` (t', c')
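-- The @msum@ algorithm below (not implemented in this module)
-- *completely* eliminates rounding errors and loss of significance
-- due to catastrophic cancellation during summation. It is Raymond
-- Hettinger's ActiveState Python recipe 393090, "Binary floating point
-- summation accurate to full precision"; also see the other
-- implementations given there. For Python's actual C implementation,
-- see @math_fsum@ in CPython's @Modules/mathmodule.c@.
--
-- For merely *mitigating* errors rather than completely eliminating
-- them, use compensated (Kahan) summation, as 'accurateProduct' does.
--
-- A good test case is @msum([1, 1e100, 1, -1e100] * 10000) == 20000.0@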
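{-
-- For a proof of correctness, see Jonathan Richard Shewchuk,
-- "Adaptive Precision Floating-Point Arithmetic and Fast Robust
-- Geometric Predicates" (Discrete & Computational Geometry, 1997).
--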
def msum(xs):
partials = [] # sorted, non-overlapping partial sums
    # N.B., the actual C implementation uses a 32-element array, doubling its size as needed
for x in xs:
i = 0
for y in partials: # for(i = j = 0; j < n; j++)
if abs(x) < abs(y):
x, y = y, x
hi = x + y
lo = y - (hi - x)
if lo != 0.0:
partials[i] = lo
i += 1
x = hi
# does an append of x while dropping all the partials after
# i. The C version does n=i; and leaves the garbage in place
partials[i:] = [x]
# BUG: this last step isn't entirely correct and can lose
# precision
return sum(partials, 0.0)
-}