module Data.Adaptive (Adaptive(..), fromFloatingPoint, approx, approx', approxFast, splitter, epsilon) where
import Data.List
import Data.Bits
import Data.Ratio
type FloatT = Double
mergeBy _ xs [] = xs
mergeBy _ [] ys = ys
mergeBy cmp (x:xs) (y:ys) = case cmp x y of
GT -> y : mergeBy cmp (x:xs) ys
_ -> x : mergeBy cmp xs (y:ys)
newtype Adaptive a = Adaptive [a]
instance (Show a, RealFloat a) => Show (Adaptive a) where
showsPrec n x = showParen (n > 6 && x < 0) $ showsPrec 0 (approx x) . ('~':)
liftAdaptive1 f = fromFloatingPoint . f . approxFast
liftAdaptive2 f x y = fromFloatingPoint $ approxFast x `f` approxFast y
instance (RealFloat a, Real a) => Real (Adaptive a) where
toRational = toRational . approxFast
instance (RealFloat a, Floating a) => Floating (Adaptive a) where
pi = fromFloatingPoint pi
exp = liftAdaptive1 exp
sqrt = sqrtA
log = liftAdaptive1 log
(**) = liftAdaptive2 (**)
logBase = liftAdaptive2 logBase
sin = liftAdaptive1 sin
tan = liftAdaptive1 tan
cos = liftAdaptive1 cos
asin = liftAdaptive1 asin
atan = liftAdaptive1 atan
acos = liftAdaptive1 acos
sinh = liftAdaptive1 sinh
tanh = liftAdaptive1 tanh
cosh = liftAdaptive1 cosh
asinh = liftAdaptive1 asinh
atanh = liftAdaptive1 atanh
acosh = liftAdaptive1 acosh
instance (RealFloat a) => RealFloat (Adaptive a) where
floatRadix = floatRadix . approxFast
floatDigits = floatDigits . approxFast
floatRange = floatRange . approxFast
decodeFloat = decodeFloat . approxFast
encodeFloat s r = fromFloatingPoint (encodeFloat s r)
exponent = exponent . approxFast
significand = liftAdaptive1 significand
scaleFloat x = liftAdaptive1 (scaleFloat x)
isNaN = isNaN . approxFast
isInfinite = isInfinite . approxFast
isDenormalized = isDenormalized . approxFast
isNegativeZero = isNegativeZero . approxFast
isIEEE = isIEEE . approxFast
atan2 = liftAdaptive2 atan2
instance (RealFloat a, RealFrac a) => RealFrac (Adaptive a) where
properFraction x = (y, fromFloatingPoint x')
where (y, x') = properFraction (approxFast x)
instance (Num a, RealFloat a) => Eq (Adaptive a) where
a == b | null x = True
| otherwise = False
where Adaptive x = a b
instance (Num a, RealFloat a) => Ord (Adaptive a) where
compare a b | null x = EQ
| head x > 0 = GT
| otherwise = LT
where Adaptive x = a b
instance (Num a, RealFloat a) => Num (Adaptive a) where
Adaptive x + Adaptive y = Adaptive . compress $ fastExpSum x y
Adaptive x Adaptive y = Adaptive x + Adaptive (map negate y)
Adaptive x * Adaptive y = Adaptive $ expProd x y
negate (Adaptive x) = Adaptive (map negate x)
abs (Adaptive []) = Adaptive []
abs (Adaptive x) | head x < 0 = Adaptive (map negate x)
| otherwise = Adaptive x
signum (Adaptive []) = Adaptive []
signum (Adaptive (x:_)) = Adaptive [signum x]
fromInteger 0 = Adaptive []
fromInteger x = Adaptive [fromInteger x]
instance (RealFloat a, Fractional a) => Fractional (Adaptive a) where
fromRational r = fromIntegral (numerator r) / fromIntegral (denominator r)
a / b = Adaptive $ a `divA` b
divA x y | d == 0 = []
| e == 0 || abs e >= abs x = [d]
| otherwise = d : divA e y
where d = approxFast x / approxFast y
e = x fromFloatingPoint d * y
sqrtA x = r + bab r
where r = liftAdaptive1 sqrt x
bab e | c == 0 = 0
| otherwise = c + bab (e + c)
where c = (x / e e) / 2
epsilon :: (RealFloat a) => a
epsilon = x
where x = scaleFloat ( floatDigits x) 1
splitter :: (RealFloat a) => a
splitter = s
where s = scaleFloat ((floatDigits s + 1) `shiftR` 1) 1 + 1
fromFloatingPoint :: (RealFloat a) => a -> Adaptive a
fromFloatingPoint 0 = Adaptive []
fromFloatingPoint x = Adaptive [x]
approx :: (RealFloat a, Ord a) => Adaptive a -> a
approx (Adaptive []) = 0
approx (Adaptive (x:xs)) = foldr (+) x . takeWhile ((> x*epsilon).abs) $ xs
approx' :: (Real a, RealFloat b) => Adaptive a -> b
approx' (Adaptive []) = 0
approx' (Adaptive (x:xs)) = foldr (+) x' .
takeWhile ((> x' * epsilon) . abs) .
map realToFrac $ xs
where x' = realToFrac x
approxFast :: (Num a) => Adaptive a -> a
approxFast (Adaptive []) = 0
approxFast (Adaptive (x:_)) = x
fastTwoSum :: (Num a) => a -> a -> (# a, a #)
fastTwoSum a b = (# x, y #)
where x = a + b
b' = x a
y = b b'
twoSum :: (Num a) => a -> a -> (# a, a #)
twoSum !a !b = (# x, y #)
where !x = a + b
!b' = x a
!a' = x b'
!br = b b'
!ar = a a'
!y = ar + br
fastExpSum :: (RealFloat a) => [a] -> [a] -> [a]
fastExpSum [] x = x
fastExpSum x [] = x
fastExpSum e f = filter (/= 0) (q:hs)
where g = mergeBy cmp e f
cmp x y = compare (exponent y) (exponent x)
(# q, hs #) = mapAccumR1 twoSum g
mapAccumR1 f [x] = (# x, [] #)
mapAccumR1 f (x:xs) = (# a', y:xs' #)
where (# a, xs' #) = mapAccumR1 f xs
(# a', y #) = f a x
split !a = (# ah, al #)
where !c = splitter * a
!ab = c a
!ah = c ab
!al = a ah
twoProd !a !b = (# x, y #)
where !x = a * b
(# !ah, !al #) = split a
(# !bh, !bl #) = split b
!err = x ah * bh al * bh ah * bl
!y = al * bl err
scaleExp _ [] = []
scaleExp b es = filter (/= 0) . (uncurry (:)) . foldr f (q0, [h0]) . init $ es
where (# q0, h0 #) = twoProd (last es) b
f e (q, h) = (q'', h2 : h1 : h)
where (# th, tl #) = twoProd e b
(# q', h1 #) = twoSum q tl
(# q'', h2 #) = fastTwoSum th q'
expProd [] _ = []
expProd _ [] = []
expProd x [y] = scaleExp y x
expProd [x] y = scaleExp x y
expProd (x:xs) ys = foldl' f (scaleExp x ys) xs
where f h e = fastExpSum h . scaleExp e $ ys
compress e = comp up . comp down $ e
where comp _ [] = []
comp _ [x] = [x]
comp f (x:xs) = uncurry (:) . foldl' f (x, []) $ xs
down (q, h) e | ql /= 0 = (ql, qu:h)
| otherwise = (qu, h)
where (# qu, ql #) = fastTwoSum q e
up (q, h) e | ql /= 0 = (qu, ql:h)
| otherwise = (qu, h)
where (# qu, ql #) = fastTwoSum e q