{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

module NumHask.Data.LogField
  ( -- * @LogField@
    LogField()
    -- ** Isomorphism to normal-domain
  , logField
  , fromLogField
    -- ** Isomorphism to log-domain
  , logToLogField
  , logFromLogField
    -- ** Additional operations
  , accurateSum, accurateProduct
  , pow
  ) where

import GHC.Generics ( Generic , Generic1 )
import Data.Data ( Data )
import NumHask.Algebra.Additive
import NumHask.Algebra.Multiplicative
import NumHask.Algebra.Distribution
import NumHask.Algebra.Field
import NumHask.Algebra.Integral
import NumHask.Algebra.Rational
import NumHask.Algebra.Metric
import Prelude hiding
  ( Num(..)
  , negate
  , sin
  , cos
  , sqrt
  , (/)
  , atan
  , pi
  , exp
  , log
  , recip
  , (**)
  , toInteger
  )
import qualified Data.Foldable as F

-- LogField is adapted from LogFloat
----------------------------------------------------------------
--                                                  ~ 2015.08.06
--
-- Module      :  Data.Number.LogFloat
-- Copyright   :  Copyright (c) 2007--2015 wren gayle romano
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  stable
-- Portability :  portable (with CPP, FFI)
-- Link        :  https://hackage.haskell.org/package/logfloat
----------------------------------------------------------------

----------------------------------------------------------------
--
-- | A @LogField@ is just a 'Field' value with a special
-- interpretation: the wrapped value is the logarithm of the number
-- it stands for. The 'logField' function is exported instead of the
-- constructor, in order to ensure semantic conversion. At present
-- the 'Show' instance converts back to the normal-domain, and hence
-- will underflow at that point. This behavior may change in the
-- future.
--
-- Because 'logField' performs the semantic conversion, we can use
-- operators which say what we *mean* rather than saying what we're
-- actually doing to the underlying representation. That is,
-- equivalences like the following are true[1] thanks to type-class
-- overloading:
--
-- > logField (p + q) == logField p + logField q
-- > logField (p * q) == logField p * logField q
--
-- Performing operations in the log-domain is cheap, prevents
-- underflow, and is otherwise very nice for dealing with minuscule
-- probabilities. However, crossing into and out of the log-domain
-- is expensive and should be avoided as much as possible. In
-- particular, if you're doing a series of multiplications as in
-- @lp * logField q * logField r@ it's faster to do
-- @lp * logField (q * r)@ if you're reasonably sure the
-- normal-domain multiplication won't underflow; because that way
-- you enter the log-domain only once, instead of twice. Also note
-- that, for precision, if you're doing more than a few
-- multiplications in the log-domain, you should use
-- 'accurateProduct' rather than using '(*)' repeatedly.
--
-- Even more particularly, you should /avoid addition/ whenever
-- possible. Addition is provided because sometimes we need it, and
-- the proper implementation is not immediately apparent. However,
-- between two @LogField@s addition requires crossing the exp\/log
-- boundary twice; with a @LogField@ and a 'Double' it's three
-- times, since the regular number needs to enter the log-domain
-- first. This makes addition incredibly slow. Again, if you can
-- parenthesize to do normal-domain operations first, do it!
--
-- [1] That is, true up-to underflow and floating point fuzziness.
-- Which is, of course, the whole point of this module.
newtype LogField a = LogField a
  deriving
    ( Eq
    , Ord
    , Read
    , Data
    , Generic
    , Generic1
    , Functor
    , Foldable
    , Traversable
    )

----------------------------------------------------------------
-- To show it, we want to show the normal-domain value rather than
-- the log-domain value. Also, if someone managed to break our
-- invariants (e.g.
-- by passing in a negative and no one's pulled on
-- the thunk yet) then we want to crash before printing the
-- constructor, rather than after. N.B. This means the show will
-- underflow\/overflow in the same places as normal doubles since
-- we underflow at the @exp@. Perhaps this means we should show the
-- log-domain value instead.
instance (ExpField a, Show a) => Show (LogField a) where
  showsPrec p (LogField x) =
    -- Force the normal-domain value before emitting any output.
    let y = exp x
    in  y `seq` showParen (p > 9) (showString "LogField " . showsPrec 11 y)

----------------------------------------------------------------
-- | Constructor which does semantic conversion from normal-domain
-- to log-domain by taking the logarithm of its argument.
--
-- The argument is /not/ validated: the guards that used to reject
-- negative and NaN inputs have been removed (see the TODO below),
-- so such inputs are handed straight to 'log'. If @p@ is
-- non-negative, then the following equivalence holds:
--
-- > logField p == logToLogField (log p)
logField :: (ExpField a) => a -> LogField a
{-# INLINE [0] logField #-}
logField = LogField . log
-- TODO: figure out what to do here, removed guards

-- | Constructor which assumes the argument is already in the
-- log-domain.
logToLogField :: a -> LogField a
logToLogField = LogField

-- | Semantically convert our log-domain value back into the
-- normal-domain. Beware of overflow\/underflow. The following
-- equivalence holds (without qualification):
--
-- > fromLogField == exp . logFromLogField
--
fromLogField :: ExpField a => LogField a -> a
{-# INLINE [0] fromLogField #-}
fromLogField (LogField x) = exp x

-- | Return the log-domain value itself without conversion.
logFromLogField :: LogField a -> a
logFromLogField (LogField x) = x

-- These are our module-specific versions of "log\/exp" and "exp\/log";
-- They do the same things but also have a @LogField@ in between
-- the logarithm and exponentiation. In order to ensure these rules
-- fire, we have to delay the inlining on two of the four
-- con-\/destructors.
--
-- N.B. the rules must mention the converting function 'logField',
-- not the raw constructor: @fromLogField (LogField x)@ is @exp x@,
-- so rewriting it to @x@ would change the result.
{-# RULES
-- Out of log-domain and back in
"log/fromLogField"      forall x. log (fromLogField x) = logFromLogField x
-- TODO: Rewrite-rule too complicated
"logField/fromLogField" forall x. logField (fromLogField x) = x

-- Into log-domain and back out
"fromLogField/logField" forall x. fromLogField (logField x) = x
  #-}

-- | @log (1 + x)@, computed naively as written.
log1p :: ExpField a => a -> a
{-# INLINE [0] log1p #-}
log1p x = log (one + x)

-- | @exp x - 1@, computed naively as written.
expm1 :: ExpField a => a -> a
{-# INLINE [0] expm1 #-}
expm1 x = exp x - one

{-# RULES
-- Into log-domain and back out
"expm1/log1p" forall x. expm1 (log1p x) = x

-- Out of log-domain and back in
"log1p/expm1" forall x. log1p (expm1 x) = x
  #-}

-- | Normal-domain addition carried out on log-domain values. The
-- larger operand is factored out, so the argument handed to 'exp'
-- is never positive.
instance (ExpField a, LowerBoundedField a, Ord a) =>
         AdditiveMagma (LogField a) where
  x@(LogField x') `plus` y@(LogField y')
    | x == zero && y == zero = zero
    | x == zero = y
    | y == zero = x
    | x >= y = LogField (x' + log1p (exp (y' - x')))
    | otherwise = LogField (y' + log1p (exp (x' - y')))

-- | The normal-domain zero is stored as 'negInfinity'.
instance (LowerBoundedField a, ExpField a, Ord a) =>
         AdditiveUnital (LogField a) where
  zero = LogField negInfinity

instance (LowerBoundedField a, ExpField a, Ord a) =>
         AdditiveAssociative (LogField a)

instance (LowerBoundedField a, ExpField a, Ord a) =>
         AdditiveCommutative (LogField a)

instance (LowerBoundedField a, ExpField a, Ord a) =>
         Additive (LogField a)

-- | Normal-domain multiplication is addition of the log-domain
-- values. A 'negInfinity' operand short-circuits to 'negInfinity'.
instance (AdditiveMagma a, LowerBoundedField a, Eq a) =>
         MultiplicativeMagma (LogField a) where
  (LogField x) `times` (LogField y)
    | x == negInfinity || y == negInfinity = LogField negInfinity
    | otherwise = LogField (x `plus` y)

-- | The normal-domain one is stored as the underlying 'zero'.
instance (AdditiveUnital a, LowerBoundedField a, Eq a) =>
         MultiplicativeUnital (LogField a) where
  one = LogField zero

instance (AdditiveAssociative a, LowerBoundedField a, Eq a) =>
         MultiplicativeAssociative (LogField a)

instance (AdditiveCommutative a, LowerBoundedField a, Eq a) =>
         MultiplicativeCommutative (LogField a)

-- | The reciprocal negates the log-domain value.
instance (AdditiveInvertible a, LowerBoundedField a, Eq a) =>
         MultiplicativeInvertible (LogField a) where
  recip (LogField x) = LogField $ negate x

instance ( AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveCommutative a
         , Additive a
         , LowerBoundedField a
         , Eq a
         ) =>
         Multiplicative (LogField a)

instance ( AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveInvertible a
         , AdditiveLeftCancellative a
         , LowerBoundedField a
         , Eq a
         ) =>
         MultiplicativeLeftCancellative (LogField a)

instance ( AdditiveUnital a
         , AdditiveAssociative a
         , AdditiveInvertible a
         , AdditiveRightCancellative a
         , LowerBoundedField a
         , Eq a
         ) =>
         MultiplicativeRightCancellative (LogField a)

instance ( Multiplicative (LogField a)
         , AdditiveInvertible a
         , AdditiveGroup a
         , LowerBoundedField a
         , Eq a
         ) =>
         MultiplicativeGroup (LogField a)

instance (LowerBoundedField a, ExpField a, Ord a, AdditiveMagma a) =>
         Distribution (LogField a)

-- unable to provide this instance because there is no Field (LogField a) instance
-- instance (Field (LogField a), ExpField a, LowerBoundedField a, Ord a) => ExpField (LogField a) where
--   exp (LogField x) = (LogField $ exp x)
--   log (LogField x) = (LogField $ log x)
--   (**) x (LogField y) = pow x $ exp y

-- The conversion instances below all go through the normal-domain:
-- 'logField' on the way in, 'fromLogField' on the way out.
instance (FromInteger a, ExpField a) => FromInteger (LogField a) where
  fromInteger = logField . fromInteger

instance (ToInteger a, ExpField a) => ToInteger (LogField a) where
  toInteger = toInteger . fromLogField

instance (FromRatio a, ExpField a) => FromRatio (LogField a) where
  fromRatio = logField . fromRatio

instance (ToRatio a, ExpField a) => ToRatio (LogField a) where
  toRatio = toRatio . fromLogField

-- | Both tests are performed on the normal-domain values.
instance (Epsilon a, ExpField a, LowerBoundedField a, Ord a) =>
         Epsilon (LogField a) where
  nearZero (LogField x) = nearZero $ exp x
  aboutEqual (LogField x) (LogField y) = aboutEqual (exp x) (exp y)

----------------------------------------------------------------
-- | /O(1)/.
-- Compute powers in the log-domain; that is, the following
-- equivalence holds (modulo underflow and all that):
--
-- > logField (p ** m) == logField p `pow` m
--
-- /Since: 0.13/
pow :: (ExpField a, LowerBoundedField a, Ord a) => LogField a -> a -> LogField a
{-# INLINE pow #-}
infixr 8 `pow`
pow x@(LogField x') m
  -- zero to the power zero: log-domain 'zero', i.e. normal-domain one
  | x == zero && m == zero = LogField zero
  -- zero to any other power stays zero
  -- NOTE(review): this also covers negative exponents -- confirm
  -- that is intended.
  | x == zero = x
  -- exponentiation is multiplication of the log-domain value
  | otherwise = LogField $ m * x'

-- Some good test cases:
-- for @logsumexp == log . accurateSum . map exp@:
--     logsumexp[0,1,0] should be about 1.55
-- for correctness of avoiding underflow:
--     logsumexp[1000,1001,1000]   ~~ 1001.55 ==  1000 + 1.55
--     logsumexp[-1000,-999,-1000] ~~ -998.45 == -1000 + 1.55
--
-- | /O(n)/. Compute the sum of a finite container of 'LogField's,
-- being careful to avoid underflow issues. That is, the following
-- equivalence holds (modulo underflow and all that):
--
-- > logField . sum == accurateSum . map logField
--
-- Every summand is rescaled by the largest one before leaving the
-- log-domain, so the largest term contributes exactly @exp 0@.
--
-- Edge cases:
--
-- * An empty container sums to @logField zero@ instead of failing
--   inside 'maximum'.
--
-- * If the largest log-domain value equals @log zero@ or its
--   negation (minus\/plus infinity for IEEE floating point), that
--   value is returned as is, because rescaling by it would compute
--   @inf - inf@.
--
-- /N.B./, this function requires two passes over the input. Thus,
-- it is not amenable to list fusion, and hence will use a lot of
-- memory when summing long lists.
{-# INLINE accurateSum #-}
accurateSum :: (ExpField a, Foldable f, Ord a) => f (LogField a) -> LogField a
accurateSum xs
  | F.null xs = logField zero
  | theMax == logZero || theMax == zero - logZero = LogField theMax
  | otherwise = LogField (theMax + log theSum)
  where
    -- Only demanded once the container is known to be non-empty.
    LogField theMax = maximum xs
    -- Log-domain image of the normal-domain zero.
    logZero = log zero
    -- compute @\sum_{x \in xs} \exp(x - theMax)@
    theSum = F.foldl' (\acc (LogField x) -> acc + exp (x - theMax)) zero xs

-- | /O(n)/. Compute the product of a finite container of
-- 'LogField's, being careful to avoid numerical error due to loss
-- of precision. That is, the following equivalence holds (modulo
-- underflow and all that):
--
-- > logField . product == accurateProduct . map logField
--
-- The log-domain values are added up with a running compensation
-- term (Kahan summation).
{-# INLINE accurateProduct #-}
accurateProduct :: (ExpField a, Foldable f) => f (LogField a) -> LogField a
accurateProduct = LogField . fst . F.foldr kahanPlus (zero, zero)
  where
    -- @t@ is the running total, @c@ the running compensation.
    kahanPlus (LogField x) (t, c) =
      let y = x - c
          t' = t + y
          c' = (t' - t) - y
      in  (t', c')

-- NOTE(review): the reference links originally embedded in the
-- comments below are missing from this copy; restore them from the
-- upstream logfloat sources.
--
-- This version *completely* eliminates rounding errors and loss
-- of significance due to catastrophic cancellation during summation.
-- Also see the other
-- implementations given there. For Python's actual C implementation,
-- see math_fsum in
--
--
-- For merely *mitigating* errors rather than completely eliminating
-- them, see .
--
-- A good test case is @msum([1, 1e100, 1, -1e100] * 10000) == 20000.0@
{-
-- For proof of correctness, see
--
def msum(xs):
    partials = [] # sorted, non-overlapping partial sums
    # N.B., the actual C implementation uses a 32 array, doubling size as needed
    for x in xs:
        i = 0
        for y in partials: # for(i = j = 0; j < n; j++)
            if abs(x) < abs(y):
                x, y = y, x
            hi = x + y
            lo = y - (hi - x)
            if lo != 0.0:
                partials[i] = lo
                i += 1
            x = hi
        # does an append of x while dropping all the partials after
        # i. The C version does n=i; and leaves the garbage in place
        partials[i:] = [x]
    # BUG: this last step isn't entirely correct and can lose
    # precision
    return sum(partials, 0.0)
-}