{-# LANGUAGE CPP, ForeignFunctionInterface, EmptyDataDecls, FlexibleInstances, FlexibleContexts, TypeFamilies #-}
-- | Floating-point values tagged at the type level with a rounding direction.
--
-- Arithmetic on 'Round' values is delegated to foreign routines that receive
-- the rounding mode selected by the direction tag ('Up', 'Down', 'Trunc' or
-- 'ToNearest') as their first argument.
module Numeric.Rounding
( Round(..)
, Rounding
, Precision
, Up, Down, Trunc, ToNearest
, up, down, trunc
, runUp, runDown, runTrunc
) where
import Control.Applicative
import GHC.Real
import Data.Foldable
import Data.Traversable
import Data.Array
import Foreign
import Foreign.C.Types
import Numeric.Extras
#include <math.h>
#include <fenv.h>
-- TODO:
--   * tweak the least-significant bits of pi
--   * modify the Enum instance so that it rounds properly
--   * implement complex numbers
-- | A value of type @a@ tagged with a phantom rounding direction @dir@.
--
-- Being a newtype, the wrapper stores nothing beyond the value itself;
-- @dir@ only selects which 'Rounding' instance the numeric instances use.
newtype Round dir a = Round a
  deriving (Show, Read, Eq, Ord, Bounded)
-- Rewrite rules for 'realToFrac' at 'Round': where the types fit, the
-- conversion is replaced by simply removing or applying the 'Round'
-- constructor instead of performing a numeric conversion.
{-# RULES
"realToFrac/Round->a" realToFrac = \(Round x) -> x
"realToFrac/a->Round" realToFrac = Round
#-}
-- | Maps over the wrapped value; the direction tag is left untouched.
instance Functor (Round dir) where
  fmap g wrapped = case wrapped of
    Round x -> Round (g x)
-- | A 'Round' holds exactly one element, so folding just applies the function.
instance Foldable (Round dir) where
  foldMap g wrapped = case wrapped of
    Round x -> g x
-- | Runs the effect on the single wrapped value and re-wraps the result.
instance Traversable (Round dir) where
  traverse g wrapped = case wrapped of
    Round x -> fmap Round (g x)
-- | Rounding directions, indexed by the phantom tag carried by 'Round'.
class Rounding dir where
  -- | The rounding-mode constant handed to the foreign arithmetic routines.
  mode :: Round dir a -> CInt

  -- | How to convert a fractional value to an integral one in this
  -- direction. The 'Round' argument only selects the instance.
  rounding :: (Integral b, RealFrac a) => Round dir c -> (a -> b)

  -- | The value of pi for this direction, taken from the given 'Ops' table.
  pi_m :: Precision a => Ops a -> Round dir a
-- Uninhabited tag types used as the @dir@ parameter of 'Round'; each one has
-- a 'Rounding' instance choosing the corresponding FE_* rounding mode.

-- | Round to nearest (FE_TONEAREST).
data ToNearest
-- | Round toward zero (FE_TOWARDZERO).
data Trunc
-- | Round upward (FE_UPWARD).
data Up
-- | Round downward (FE_DOWNWARD).
data Down
-- | Round-to-nearest: uses the FE_TONEAREST mode and 'round' for
-- fractional-to-integral conversion.
instance Rounding ToNearest where
  mode _ = #const FE_TONEAREST
  rounding _ = round
  -- Wrap the underlying type's own 'pi' ('Precision' implies 'Floating' via
  -- 'RealFloat'). A bare @pi@ here would be taken at type @Round ToNearest a@,
  -- whose 'Floating' instance defines 'pi' as @pi_m (ops r)@ -- i.e. this very
  -- method -- so evaluating it would never terminate.
  pi_m _ = Round pi
-- | Round toward zero: uses the FE_TOWARDZERO mode and 'truncate'.
instance Rounding Trunc where
  mode _ = #const FE_TOWARDZERO
  rounding _ = truncate
  -- Reads the same pi constant as the 'Down' instance ('pi_l').
  -- NOTE(review): 'unsafePerformIO' is only sound here if the pointed-to
  -- constant is never written after start-up -- confirm in the C sources.
  pi_m table = Round (realToFrac (unsafePerformIO (peek (pi_l table))))
-- | Round upward: uses the FE_UPWARD mode and 'ceiling'.
instance Rounding Up where
  mode _ = #const FE_UPWARD
  rounding _ = ceiling
  -- Reads the 'pi_u' constant through its pointer.
  -- NOTE(review): 'unsafePerformIO' is only sound here if the pointed-to
  -- constant is never written after start-up -- confirm in the C sources.
  pi_m table = Round (realToFrac (unsafePerformIO (peek (pi_u table))))
-- | Round downward: uses the FE_DOWNWARD mode and 'floor'.
instance Rounding Down where
  mode _ = #const FE_DOWNWARD
  rounding _ = floor
  -- Reads the 'pi_l' constant through its pointer.
  -- NOTE(review): 'unsafePerformIO' is only sound here if the pointed-to
  -- constant is never written after start-up -- confirm in the C sources.
  pi_m table = Round (realToFrac (unsafePerformIO (peek (pi_l table))))
-- | Shape of a unary foreign primitive: rounding mode, then one operand in
-- the C representation of @a@.
type U a = CInt -> C a -> C a
-- | Shape of a binary foreign primitive: rounding mode, then two operands in
-- the C representation of @a@.
type B a = CInt -> C a -> C a -> C a
-- | Floating-point types that come with a table of mode-aware primitives.
class (RealFloat a, RealExtras a, Enum a) => Precision a where
  -- | The primitive table for @a@; the argument only fixes the type.
  ops :: Round m a -> Ops a

  -- | Turn a unary primitive picked out of 'Ops' into a function on 'Round'
  -- values, supplying the rounding mode of @m@.
  lift1 :: Rounding m => (Ops a -> U a) -> Round m a -> Round m a

  -- | Binary counterpart of 'lift1'.
  lift2 :: Rounding m => (Ops a -> B a) -> Round m a -> Round m a -> Round m a
-- | The foreign primitives available for one floating-point type @a@.
--
-- Function fields have the 'U' (unary) or 'B' (binary) shape: the rounding
-- mode comes first, followed by operands in the C representation @C a@.
data Ops a = Ops
  { pi_l   :: Ptr (C a)  -- ^ pi constant read by the 'Down' and 'Trunc' instances
  , pi_u   :: Ptr (C a)  -- ^ pi constant read by the 'Up' instance
  , padd   :: B a        -- ^ backs @(+)@
  , pminus :: B a        -- ^ backs @(-)@
  , ptimes :: B a        -- ^ backs @(*)@
  , pdiv   :: B a        -- ^ backs @(/)@ and 'recip'
  , pexp   :: U a        -- ^ backs 'exp'
  , ppow   :: B a        -- ^ backs @(**)@
  , plog   :: U a        -- ^ backs 'log'
  , psqrt  :: U a        -- ^ backs 'sqrt'
  , psin   :: U a        -- ^ backs 'sin'
  , pcos   :: U a        -- ^ backs 'cos'
  , ptan   :: U a        -- ^ backs 'tan'
  , psinh  :: U a        -- ^ backs 'sinh'
  , pcosh  :: U a        -- ^ backs 'cosh'
  , ptanh  :: U a        -- ^ backs 'tanh'
  , pasin  :: U a        -- ^ backs 'asin'
  , pacos  :: U a        -- ^ backs 'acos'
  , patan  :: U a        -- ^ backs 'atan'
  , pasinh :: U a        -- ^ backs 'asinh'
  , pacosh :: U a        -- ^ backs 'acosh'
  , patanh :: U a        -- ^ backs 'atanh'
  , patan2 :: B a        -- ^ backs 'atan2'
  , pfmod  :: B a        -- ^ backs 'fmod'
  , plog1p :: U a        -- ^ backs 'log1p'
  , pexpm1 :: U a        -- ^ backs 'expm1'
  , phypot :: B a        -- ^ backs 'hypot'
  , pcbrt  :: U a        -- ^ backs 'cbrt'
  , perf   :: U a        -- ^ backs 'erf'
  }
-- | Double precision: operands are converted to the C representation with
-- 'realToFrac', passed to the selected primitive together with the rounding
-- mode, and the result is converted back.
--
-- NOTE(review): 'realToFrac' presumably relies on GHC rewrite rules to be a
-- cheap, value-preserving conversion here -- confirm behaviour for NaN and
-- negative zero when compiling without optimisation.
instance Precision Double where
  lift1 pick r@(Round x) = Round (realToFrac result)
    where
      result = pick (ops r) (mode r) (realToFrac x)

  lift2 pick r@(Round x) (Round y) = Round (realToFrac result)
    where
      result = pick (ops r) (mode r) (realToFrac x) (realToFrac y)

  ops _ = Ops
    { pi_l   = pi_d_l
    , pi_u   = pi_d_u
    , padd   = madd
    , pminus = mminus
    , ptimes = mtimes
    , pdiv   = mdiv
    , pexp   = mexp
    , ppow   = mpow
    , plog   = mlog
    , psqrt  = msqrt
    , psin   = msin
    , pcos   = mcos
    , ptan   = mtan
    , psinh  = msinh
    , pcosh  = mcosh
    , ptanh  = mtanh
    , pasin  = masin
    , pacos  = macos
    , patan  = matan
    , pasinh = masinh
    , pacosh = macosh
    , patanh = matanh
    , patan2 = matan2
    , pfmod  = mfmod
    , plog1p = mlog1p
    , pexpm1 = mexpm1
    , phypot = mhypot
    , pcbrt  = mcbrt
    , perf   = merf
    }
-- | Single precision: operands are converted to the C representation with
-- 'realToFrac', passed to the selected @f@-suffixed primitive together with
-- the rounding mode, and the result is converted back.
--
-- NOTE(review): 'realToFrac' presumably relies on GHC rewrite rules to be a
-- cheap, value-preserving conversion here -- confirm behaviour for NaN and
-- negative zero when compiling without optimisation.
instance Precision Float where
  lift1 pick r@(Round x) = Round (realToFrac result)
    where
      result = pick (ops r) (mode r) (realToFrac x)

  lift2 pick r@(Round x) (Round y) = Round (realToFrac result)
    where
      result = pick (ops r) (mode r) (realToFrac x) (realToFrac y)

  ops _ = Ops
    { pi_l   = pi_f_l
    , pi_u   = pi_f_u
    , padd   = maddf
    , pminus = mminusf
    , ptimes = mtimesf
    , pdiv   = mdivf
    , pexp   = mexpf
    , ppow   = mpowf
    , plog   = mlogf
    , psqrt  = msqrtf
    , psin   = msinf
    , pcos   = mcosf
    , ptan   = mtanf
    , psinh  = msinhf
    , pcosh  = mcoshf
    , ptanh  = mtanhf
    , pasin  = masinf
    , pacos  = macosf
    , patan  = matanf
    , pasinh = masinhf
    , pacosh = macoshf
    , patanh = matanhf
    , patan2 = matan2f
    , pfmod  = mfmodf
    , plog1p = mlog1pf
    , pexpm1 = mexpm1f
    , phypot = mhypotf
    , pcbrt  = mcbrtf
    , perf   = merff
    }
-- | Arithmetic on 'Round' values. Addition, subtraction and multiplication
-- go through the mode-aware primitives; operations that only inspect or flip
-- the sign act on the wrapped value directly.
instance (Rounding d, Precision a) => Num (Round d a) where
  -- NOTE(review): this uses the wrapped type's own conversion, so integers
  -- that are not exactly representable are presumably not rounded in
  -- direction @d@ -- confirm whether that is intended.
  fromInteger n = Round (fromInteger n)
  (+) = lift2 padd
  (-) = lift2 pminus
  (*) = lift2 ptimes
  -- Negation is a sign flip and never needs rounding, so it is done on the
  -- wrapped value like 'abs' and 'signum'. The class default @0 - x@ would go
  -- through 'pminus' and compute @negate 0@ as @0 - 0@, losing the sign of
  -- zero in most rounding modes.
  negate (Round a) = Round (negate a)
  abs (Round a) = Round (abs a)
  signum (Round a) = Round (signum a)
-- | Division uses the mode-aware 'pdiv' primitive; rationals are converted
-- by 'fromRat'.
instance (Rounding d, Precision a) => Fractional (Round d a) where
  numer / denom = lift2 pdiv numer denom
  recip denom = lift2 pdiv 1 denom
  fromRational q = fromRat q
-- | Enumeration in steps of one, using the numeric enumeration helpers from
-- "GHC.Real" for ranges.
instance (Rounding d, Precision a) => Enum (Round d a) where
  succ x = x + 1
  pred x = x - 1
  toEnum i = Round (toEnum i)           -- TODO: tweak?
  fromEnum (Round x) = fromEnum x       -- TODO: tweak?
  enumFrom       = numericEnumFrom
  enumFromThen   = numericEnumFromThen
  enumFromTo     = numericEnumFromTo
  enumFromThenTo = numericEnumFromThenTo
-- | Transcendental functions, each forwarded to the matching mode-aware
-- primitive from 'Ops'.
instance (Rounding d, Precision a) => Floating (Round d a) where
  -- The result is fed back to 'ops' only so that its type picks the table;
  -- 'ops' does not inspect its argument.
  pi = self
    where
      self = pi_m (ops self)
  exp   = lift1 pexp
  (**)  = lift2 ppow
  log   = lift1 plog
  sqrt  = lift1 psqrt
  sin   = lift1 psin
  cos   = lift1 pcos
  tan   = lift1 ptan
  asin  = lift1 pasin
  acos  = lift1 pacos
  atan  = lift1 patan
  sinh  = lift1 psinh
  cosh  = lift1 pcosh
  tanh  = lift1 ptanh
  asinh = lift1 pasinh
  acosh = lift1 pacosh
  atanh = lift1 patanh
-- | Conversion to 'Rational' is that of the wrapped value.
instance (Rounding d, Precision a) => Real (Round d a) where
  toRational wrapped = case wrapped of
    Round x -> toRational x -- tweak?
-- | Fractional-to-integral conversions act on the wrapped value; each method
-- keeps its usual meaning regardless of the direction tag.
instance (Rounding d, Precision a) => RealFrac (Round d a) where
  properFraction     = properFrac
  truncate (Round x) = truncate x
  round    (Round x) = round x
  ceiling  (Round x) = ceiling x
  floor    (Round x) = floor x
-- | Representation queries are answered by the wrapped value; only 'atan2'
-- goes through a mode-aware primitive.
instance (Rounding d, Precision a) => RealFloat (Round d a) where
  floatRadix     (Round x) = floatRadix x
  floatDigits    (Round x) = floatDigits x
  floatRange     (Round x) = floatRange x
  decodeFloat    (Round x) = decodeFloat x
  encodeFloat mant ex      = Round (encodeFloat mant ex)
  exponent       (Round x) = exponent x
  significand    (Round x) = Round (significand x)
  scaleFloat k   (Round x) = Round (scaleFloat k x)
  isNaN          (Round x) = isNaN x
  isInfinite     (Round x) = isInfinite x
  isDenormalized (Round x) = isDenormalized x
  isNegativeZero (Round x) = isNegativeZero x
  isIEEE         (Round x) = isIEEE x
  atan2 = lift2 patan2
-- | The extra real functions, forwarded to the mode-aware primitives. The C
-- representation of a 'Round' is that of the wrapped type.
instance (Rounding d, Precision a) => RealExtras (Round d a) where
  type C (Round d a) = C a
  fmod  = lift2 pfmod
  expm1 = lift1 pexpm1
  log1p = lift1 plog1p
  hypot = lift2 phypot
  cbrt  = lift1 pcbrt
  erf   = lift1 perf
-- * Fractional
-- | 'properFraction' for 'Round': splits the wrapped value and re-wraps the
-- fractional part. The split is bound lazily, so the pair is returned
-- without forcing it.
properFrac
  :: (Rounding dir, RealFrac a, Integral b)
  => Round dir a -> (b, Round dir a)
properFrac (Round x) = (fst parts, Round (snd parts))
  where
    parts = properFraction x
-- * Rounding Rationals
-- | Convert a 'Rational' to a 'Round' value, rounding in direction @d@.
--
-- A zero denominator yields +Infinity, NaN or -Infinity according to the
-- sign of the numerator; a zero numerator yields zero. Non-zero values are
-- converted by 'fromRatSigned' so that the rounding direction is applied to
-- the signed value.
fromRat :: (Rounding d, Precision a) => Rational -> Round d a
fromRat (n :% 0) = case compare n 0 of
  GT -> 1/0   -- +Infinity
  EQ -> 0/0   -- NaN
  LT -> -1/0  -- -Infinity
fromRat (n :% d) = case compare n 0 of
  GT -> fromRat' (n :% d)
  EQ -> encodeFloat 0 0 -- Zero
  -- Do not convert the magnitude and negate afterwards: that would apply the
  -- direction to |x|, so e.g. 'Up' would push a negative value further from
  -- zero (i.e. downward). Pass the sign along instead.
  LT -> fromRatSigned True ((-n) :% d)

-- Conversion process:
-- Scale the rational number by the RealFloat base until
-- it lies in the range of the mantissa (as used by decodeFloat/encodeFloat).
-- Then round the rational to an Integer and encode it with the exponent
-- that we got from the scaling.
-- To speed up the scaling process we compute the log2 of the number to get
-- a first guess of the exponent.
fromRat' :: (Rounding d, Precision a) => Rational -> Round d a
-- Invariant: argument is strictly positive
fromRat' = fromRatSigned False

-- | Worker shared by 'fromRat' and 'fromRat''.
--
-- The 'Rational' is the strictly positive magnitude of the number; the
-- 'Bool' says whether the number being converted is negative. The magnitude
-- is scaled into mantissa range, then the /signed/ scaled value is rounded
-- with the direction's 'rounding' function and encoded.
fromRatSigned :: (Rounding d, Precision a) => Bool -> Rational -> Round d a
fromRatSigned negative x = r
  where
    -- @r@ is used below only as a type witness for the float parameters.
    b = floatRadix r
    p = floatDigits r
    (minExp0, _) = floatRange r
    minExp = minExp0 - p -- the real minimum exponent
    xMin = toRational (expt b (p-1))
    xMax = toRational (expt b p)
    p0 = (integerLogBase b (numerator x) - integerLogBase b (denominator x) - p) `max` minExp
    f = if p0 < 0 then 1 % expt b (-p0) else expt b p0 % 1
    (x', p') = scaleRat (toRational b) minExp xMin xMax p0 (x / f)
    mantissa :: Integer
    mantissa
      | negative  = rounding r (negate x')
      | otherwise = rounding r x'
    r = encodeFloat mantissa p'
-- | Rescale a value by powers of the base until it lies in @[xMin, xMax)@,
-- adjusting the exponent to compensate, or stop as soon as the exponent is
-- at or below the minimum.
--
-- Arguments: base, minimum exponent, @xMin@, @xMax@, starting exponent,
-- starting value. Returns the rescaled value and its exponent.
scaleRat :: Rational -> Int -> Rational -> Rational -> Int -> Rational -> (Rational, Int)
scaleRat base floorExp lo hi = go
  where
    go e v
      | e <= floorExp = (v, e)
      | v >= hi       = go (e + 1) (v / base)
      | v < lo        = go (e - 1) (v * base)
      | otherwise     = (v, e)
-- Bounds of the exponents for which powers of two are cached.
minExpt, maxExpt :: Int
minExpt = 0
maxExpt = 1100

-- | Integer exponentiation; powers of two with an exponent inside
-- @[minExpt, maxExpt]@ are looked up in 'expts', everything else is computed.
expt :: Integer -> Int -> Integer
expt base n
  | cached    = expts ! n
  | otherwise = base ^ n
  where
    cached = base == 2 && minExpt <= n && n <= maxExpt

-- | Table of @2^n@ for every @n@ in @[minExpt, maxExpt]@.
expts :: Array Int Integer
expts = listArray (minExpt, maxExpt) (map (2 ^) [minExpt .. maxExpt])
-- | Compute the (floor of the) logarithm of @i@ in base @b@.
--
-- Dividing @i@ by @b@ until it is smaller than @b@ would work but be very
-- slow, so a coarse estimate is obtained first by recursing on the squared
-- base, and only the remainder is refined by repeated division.
integerLogBase :: Integer -> Integer -> Int
integerLogBase b i =
  if i < b
    then 0
    else refine (i `div` (b ^ coarse)) coarse
  where
    -- Twice the logarithm in base b*b is a lower estimate in base b.
    coarse = 2 * integerLogBase (b * b) i

    refine :: Integer -> Int -> Int
    refine q acc
      | q < b     = acc
      | otherwise = refine (q `div` b) (acc + 1)
-- | Tag a value as rounding upward.
up :: a -> Round Up a
up x = Round x
{-# INLINE up #-}
-- | Tag a value as rounding downward.
down :: a -> Round Down a
down x = Round x
{-# INLINE down #-}
-- | Tag a value as rounding toward zero.
trunc :: a -> Round Trunc a
trunc x = Round x
{-# INLINE trunc #-}
-- | Remove the 'Up' tag and return the wrapped value.
runUp :: Round Up a -> a
runUp wrapped = case wrapped of
  Round x -> x
{-# INLINE runUp #-}
-- | Remove the 'Down' tag and return the wrapped value.
runDown :: Round Down a -> a
runDown wrapped = case wrapped of
  Round x -> x
{-# INLINE runDown #-}
-- | Remove the 'Trunc' tag and return the wrapped value.
runTrunc :: Round Trunc a -> a
runTrunc wrapped = case wrapped of
  Round x -> x
{-# INLINE runTrunc #-}
-- Foreign bindings for the double-precision entries of 'Ops'.
--
-- The two pointers address the pi constants read by 'pi_m'. Every function
-- takes the rounding mode as its first 'CInt' argument (it is supplied from
-- 'mode' by 'lift1' / 'lift2'), followed by one or two 'CDouble' operands.
-- The functions are imported @unsafe@ and given pure types.
foreign import ccall "rounding.h &pi_d_l"
pi_d_l :: Ptr CDouble
foreign import ccall "rounding.h &pi_d_u"
pi_d_u :: Ptr CDouble
-- Basic arithmetic.
foreign import ccall unsafe "rounding.h madd"
madd :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mminus"
mminus :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mtimes"
mtimes :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mdiv"
mdiv :: CInt -> CDouble -> CDouble -> CDouble
-- Exponentials, logarithms and roots.
foreign import ccall unsafe "rounding.h mexp"
mexp :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mpow"
mpow :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mlog"
mlog :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h msqrt"
msqrt :: CInt -> CDouble -> CDouble
-- Trigonometric and hyperbolic functions and their inverses.
foreign import ccall unsafe "rounding.h msin"
msin :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mcos"
mcos :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mtan"
mtan :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h msinh"
msinh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mcosh"
mcosh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mtanh"
mtanh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h masin"
masin :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h macos"
macos :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h matan"
matan :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h masinh"
masinh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h macosh"
macosh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h matanh"
matanh :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h matan2"
matan2 :: CInt -> CDouble -> CDouble -> CDouble
-- Primitives behind the 'RealExtras' methods.
foreign import ccall unsafe "rounding.h mfmod"
mfmod :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mlog1p"
mlog1p :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mexpm1"
mexpm1 :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mhypot"
mhypot :: CInt -> CDouble -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h mcbrt"
mcbrt :: CInt -> CDouble -> CDouble
foreign import ccall unsafe "rounding.h merf"
merf :: CInt -> CDouble -> CDouble
-- Foreign bindings for the single-precision entries of 'Ops'.
-- These two pointers address the 'CFloat' pi constants read by 'pi_m'.
foreign import ccall "rounding.h &pi_f_l"
pi_f_l :: Ptr CFloat
foreign import ccall "rounding.h &pi_f_u"
pi_f_u :: Ptr CFloat
-- Single-precision addition and subtraction; the first argument is the
-- rounding mode. These bind the @f@-suffixed C symbols, like every other
-- 'CFloat' import in this module: binding the unsuffixed names would call
-- the symbols that are imported at 'CDouble' elsewhere in this file while
-- passing 'CFloat' arguments.
foreign import ccall unsafe "rounding.h maddf"
  maddf :: CInt -> CFloat -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mminusf"
  mminusf :: CInt -> CFloat -> CFloat -> CFloat
-- Single-precision primitives. Each takes the rounding mode as its first
-- 'CInt' argument, followed by one or two 'CFloat' operands, and is imported
-- @unsafe@ with a pure type.
foreign import ccall unsafe "rounding.h mtimesf"
mtimesf :: CInt -> CFloat -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mdivf"
mdivf :: CInt -> CFloat -> CFloat -> CFloat
-- Exponentials, logarithms and roots.
foreign import ccall unsafe "rounding.h mexpf"
mexpf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mpowf"
mpowf :: CInt -> CFloat -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mlogf"
mlogf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h msqrtf"
msqrtf :: CInt -> CFloat -> CFloat
-- Trigonometric and hyperbolic functions and their inverses.
foreign import ccall unsafe "rounding.h msinf"
msinf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mcosf"
mcosf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mtanf"
mtanf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h msinhf"
msinhf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mcoshf"
mcoshf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mtanhf"
mtanhf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h masinf"
masinf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h macosf"
macosf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h matanf"
matanf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h masinhf"
masinhf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h macoshf"
macoshf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h matanhf"
matanhf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h matan2f"
matan2f :: CInt -> CFloat -> CFloat -> CFloat
-- Single-precision floating-point remainder; the first argument is the
-- rounding mode. This must bind the @mfmodf@ symbol: binding @matan2f@ here
-- would make 'fmod' on single-precision 'Round' values compute 'atan2'.
foreign import ccall unsafe "rounding.h mfmodf"
  mfmodf :: CInt -> CFloat -> CFloat -> CFloat
-- Remaining single-precision primitives behind the 'RealExtras' methods.
-- Each takes the rounding mode first, then its 'CFloat' operand(s).
foreign import ccall unsafe "rounding.h mlog1pf"
mlog1pf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mexpm1f"
mexpm1f :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mhypotf"
mhypotf :: CInt -> CFloat -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h mcbrtf"
mcbrtf :: CInt -> CFloat -> CFloat
foreign import ccall unsafe "rounding.h merff"
merff :: CInt -> CFloat -> CFloat