```{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Rank2Types #-}

-- | A reasonably efficient implementation of arbitrary-but-fixed
-- precision real numbers. This is inspired by, and partly based on,
-- "Data.Number.Fixed" and "Data.Number.CReal", but more efficient.

module Data.Number.FixedPrec (
-- * Type-level integers for precision
Precision,
P0, P1, P10, P100, P1000, P2000,
PPlus1, PPlus3, PPlus10, PPlus100, PPlus1000,

-- * Fixed-precision numbers
FixedPrec,
getprec,

-- * Static and dynamic casts
cast,
upcast,
downcast,

-- * Other operations
fractional,
log_double
) where

import Data.Ratio
import System.Random

-- ----------------------------------------------------------------------
-- * Auxiliary functions

-- ----------------------------------------------------------------------
-- ** Integer functions

-- | Divide one integer by another, rounding the quotient to the
-- closest integer instead of truncating. Half of the divisor is added
-- to the dividend before the flooring division; as a consequence a
-- tie is always rounded up. (The tie-breaking rule could be refined.)
divi :: Integer -> Integer -> Integer
divi a b = (a + halfDivisor) `div` b
  where
    halfDivisor = b `div` 2
infixl 7 `divi`

-- | Drop the @n@ least significant decimal digits of an integer,
-- i.e. divide it by 10^@n@, rounding the result via 'divi'.
decshiftR :: Int -> Integer -> Integer
decshiftR n x = x `divi` scale
  where
    scale = 10 ^ n

-- | Drop the @n@ least significant decimal digits of an integer
-- without rounding: the quotient by 10^@n@ is truncated towards zero.
dectruncR :: Int -> Integer -> Integer
dectruncR n x = quot x (10 ^ n)

-- | Append @n@ decimal zeros to an integer, i.e. multiply it by
-- 10^@n@.
decshiftL :: Int -> Integer -> Integer
decshiftL n = (* 10 ^ n)

-- | For a non-negative 'Integer', return 1 + the position of its
-- leftmost \"1\" bit (and 0 for the integer 0). The running time is
-- O(/n/ log /n/), where /n/ is the size of the integer (in digits).
--
-- The search first doubles a candidate exponent until the
-- corresponding power of two exceeds the input, and then bisects the
-- last doubling interval.
hibit :: Integer -> Int
hibit 0 = 0
hibit n = grow 1
  where
    -- Doubling phase: find the first k among 1, 2, 4, ... with n < 2^k.
    grow k
      | n < 2 ^ k = bisect k (k `div` 2) -- 2^(k/2) <= n < 2^k
      | otherwise = grow (2 * k)

    -- Bisection phase: shrink the interval until the bounds are
    -- adjacent, then answer with the upper bound.
    bisect hi lo
      | hi - lo < 2 = hi
      | n < 2 ^ mid = bisect mid lo
      | otherwise = bisect hi mid
      where
        mid = (hi + lo) `div` 2

-- | For /n/ ≥ 0, return the floor of the square root of /n/. Only
-- integer arithmetic is used, so there are no rounding errors.
-- Arguments that are zero or negative yield 0.
--
-- The iteration starts from a power of two derived from the bit
-- length of /n/ (see 'hibit') and repeatedly averages the current
-- guess /m/ with /n/ `div` /m/ until /m/² ≤ /n/ < (/m/+1)².
intsqrt :: (Integral n) => n -> n
intsqrt n
  | n <= 0 = 0
  | otherwise = newton start
  where
    bits = hibit (fromIntegral n)
    start = 2 ^ (bits `div` 2)

    newton m
      | sq <= n && n < sq + 2 * m + 1 = m
      | otherwise = newton ((m + n `div` m) `div` 2)
      where
        sq = m * m

-- ----------------------------------------------------------------------
-- ** Other general-purpose functions

-- | Given positive /b/ > 1 and /x/ > 0, return (/n/, /r/) such that
--
-- * /x/ = /r/ /b/[sup /n/] and
--
-- * 1 ≤ /r/ < /b/.
--
-- In other words, /n/ = ⌊log[sub /b/] /x/⌋ and
-- /r/ = /x/ /b/[sup −/n/]. Depending on the type this can be more
-- efficient than 'floor' ('logBase' /b/ /x/); it also works for exact
-- types such as 'Rational' and 'EReal'.
--
-- Calls 'error' when /x/ is not positive. Outside of the two directly
-- answered ranges, the problem is solved for the squared base /b/² and
-- the result is then refined by at most one division by /b/.
floorlog :: (Fractional b, Ord b) => b -> b -> (Integer, b)
floorlog b x
  | x <= 0 = error "floorlog: argument not positive"
  | 1 <= x && x < b = (0, x)
  | 1 <= x * b && x < 1 = (-1, b * x)
  | otherwise = refine (floorlog (b ^ 2) x)
  where
    -- Turn the answer for base b^2 into the answer for base b.
    refine (half, rest)
      | rest < b = (2 * half, rest)
      | otherwise = (2 * half + 1, rest / b)

-- | A version of the natural logarithm that returns a 'Double'. The
-- logarithm of just about any value fits into a 'Double'; so when not
-- much precision is needed in the mantissa, this is often faster than
-- 'log'.
--
-- The argument is split with 'floorlog' (base /e/) into an integer
-- exponent and a remainder; only the remainder is converted to
-- 'Double' before its logarithm is taken.
log_double :: (Floating a, Real a) => a -> Double
log_double x = fromInteger expo + log mantissa
  where
    (expo, rest) = floorlog (exp 1) x
    mantissa = fromRational (toRational rest) :: Double

-- ----------------------------------------------------------------------
-- * Type-level integers for precision

-- | A type class for type-level integers, capturing a precision
-- parameter. Precision is measured in decimal digits. Instances are
-- phantom tag types; only the type of the argument of 'digits'
-- matters.
class Precision e where
  -- | Get the precision, in decimal digits. The instances in this
  -- module never inspect the argument, so it may be 'undefined'.
  digits :: e -> Int

-- | Type-level tag for a precision of 0 decimal digits.
data P0

instance Precision P0 where
  digits _ = 0

-- | Type-level tag for a precision of 1 decimal digit.
data P1

instance Precision P1 where
  digits _ = 1

-- | Type-level tag for a precision of 10 decimal digits.
data P10

instance Precision P10 where
  digits _ = 10

-- | Type-level tag for a precision of 100 decimal digits.
data P100

instance Precision P100 where
  digits _ = 100

-- | Type-level tag for a precision of 1000 decimal digits.
data P1000

instance Precision P1000 where
  digits _ = 1000

-- | Type-level tag for a precision of 2000 decimal digits.
data P2000

instance Precision P2000 where
  digits _ = 2000

-- | Type-level successor: the precision @e@ plus 1 digit.
data PPlus1 e

instance Precision e => Precision (PPlus1 e) where
  digits tag = digits (inner tag) + 1
    where
      -- Type-level projection only; the value is never inspected.
      inner :: PPlus1 e -> e
      inner = undefined

-- | Type-level offset: the precision @e@ plus 3 digits.
data PPlus3 e

instance Precision e => Precision (PPlus3 e) where
  digits tag = digits (inner tag) + 3
    where
      -- Type-level projection only; the value is never inspected.
      inner :: PPlus3 e -> e
      inner = undefined

-- | Type-level offset: the precision @e@ plus 10 digits.
data PPlus10 e

instance Precision e => Precision (PPlus10 e) where
  digits tag = digits (inner tag) + 10
    where
      -- Type-level projection only; the value is never inspected.
      inner :: PPlus10 e -> e
      inner = undefined

-- | Type-level offset: the precision @e@ plus 100 digits.
data PPlus100 e

instance Precision e => Precision (PPlus100 e) where
  digits tag = digits (inner tag) + 100
    where
      -- Type-level projection only; the value is never inspected.
      inner :: PPlus100 e -> e
      inner = undefined

-- | Type-level offset: the precision @e@ plus 1000 digits.
data PPlus1000 e

instance Precision e => Precision (PPlus1000 e) where
  digits tag = digits (inner tag) + 1000
    where
      -- Type-level projection only; the value is never inspected.
      inner :: PPlus1000 e -> e
      inner = undefined

----------------------------------------------------------------------
-- * Fixed-precision numbers

-- $ Fixed-precision numbers are simply implemented as integers. The
-- integer /n/ represents the real number /n/⋅10[sup −/d/], where /d/
-- is the precision in digits.

-- | The type of fixed-precision numbers. The type parameter @e@ is a
-- phantom precision tag (see 'Precision'); the wrapped 'Integer' is
-- the number scaled by 10 to the power of that precision. Equality
-- and ordering are those of the underlying integers.
newtype FixedPrec e = F Integer
  deriving (Eq, Ord)

-- | Report the precision of a fixed-precision number, in decimal
-- digits. Only the type of the argument is used.
getprec :: (Precision e) => FixedPrec e -> Int
getprec x = digits (phantom x)
  where
    -- Recover the precision tag at the type level; never evaluated.
    phantom :: FixedPrec p -> p
    phantom = undefined

-- ----------------------------------------------------------------------
-- ** Static and dynamic casts

-- | Convert between any two 'FixedPrec' types. When the target has
-- fewer digits the surplus digits are removed with rounding
-- ('decshiftR'); otherwise zeros are appended ('decshiftL').
cast :: (Precision e, Precision f) => FixedPrec e -> FixedPrec f
cast src@(F x) = dst
  where
    dst = F scaled
    srcDigits = getprec src
    dstDigits = getprec dst -- only the type of dst is used here
    scaled
      | srcDigits >= dstDigits = decshiftR (srcDigits - dstDigits) x
      | otherwise = decshiftL (dstDigits - srcDigits) x

-- | Specialization of 'cast' that moves to a fixed-point type with
-- three additional digits of accuracy.
upcast :: (Precision e) => FixedPrec e -> FixedPrec (PPlus3 e)
upcast x = cast x

-- | Specialization of 'cast' that moves to a fixed-point type with
-- three fewer digits of accuracy.
downcast :: (Precision e) => FixedPrec (PPlus3 e) -> FixedPrec e
downcast x = cast x

-- | @with_added_digits d f x@ evaluates @f x@, adding /d/ digits of
-- accuracy to /x/ during the computation: /x/ is cast to a type whose
-- precision is /d/ digits larger before /f/ is applied to it. If /d/
-- is zero or negative, no digits are added.
--
-- The enlarged precision type is built at run time by wrapping the
-- original precision in 'PPlus1000', 'PPlus100', 'PPlus10' and
-- 'PPlus1' until /d/ is used up, which is why /f/ must be polymorphic
-- in the precision.
with_added_digits ::
  forall a f. (Precision f) =>
  Int -> (forall e. (Precision e) => FixedPrec e -> a) -> FixedPrec f -> a
with_added_digits d f x = go d (witness x)
  where
    -- Type-level projection only; the value is never inspected.
    witness :: FixedPrec p -> p
    witness = undefined

    -- The second argument only carries the precision type built so far.
    go :: forall p. (Precision p) => Int -> p -> a
    go remaining _
      | remaining >= 1000 = go (remaining - 1000) (undefined :: PPlus1000 p)
      | remaining >= 100 = go (remaining - 100) (undefined :: PPlus100 p)
      | remaining >= 10 = go (remaining - 10) (undefined :: PPlus10 p)
      | remaining > 0 = go (remaining - 1) (undefined :: PPlus1 p)
      | otherwise = f (cast x :: FixedPrec p)

-- ----------------------------------------------------------------------
-- ** Some primitive operations

-- | Scale a fixed-precision number by an integer factor. The
-- underlying integers are multiplied directly, which is marginally
-- cheaper than multiplying two fixed-precision numbers.
infixl 7 ..*
(..*) :: (Precision e) => Integer -> FixedPrec e -> FixedPrec e
(..*) n (F x) = F (n * x)

-- | Divide a fixed-precision number by an integer, rounding with
-- 'divi'. This is marginally cheaper than dividing two
-- fixed-precision numbers.
infixl 7 /..
(/..) :: (Precision e) => FixedPrec e -> Integer -> FixedPrec e
(/..) (F x) n = F (x `divi` n)

-- | Return the positive fractional part of a fixed-precision number.
-- The underlying integer is reduced with 'mod' by the representation
-- of 1, so the result lies in [0,1) whatever the sign of the input.
fractional :: (Precision e) => FixedPrec e -> FixedPrec e
fractional a@(F x) = F (x `mod` unit)
  where
    unit = decshiftL (getprec a) 1

-- ----------------------------------------------------------------------
-- ** Power series

-- | Build an infinite list of rational numbers (i.e., the
-- coefficients of a power series) from a recursive formula. The list
-- starts with 1; every further element is obtained by applying the
-- formula to the previous element and its 1-based index.
accs :: (Rational -> Integer -> Rational) -> [Rational]
accs f = go 1 1
  where
    go acc k = acc : go (f acc k) (k + 1)

-- | Evaluate a power series, given by its list of coefficients, at
-- /x/ (Horner style). Evaluation stops at the first coefficient that
-- is 0 at the working precision. This is accurate for alternating and
-- decreasing series, and provided |/x/| ≤ 1.
powerseries :: (Precision e) => [Rational] -> FixedPrec e -> FixedPrec e
powerseries [] _ = 0
powerseries (c : cs) x =
  -- we could improve upon this by checking that term * x^n < 1.
  case term of
    F 0 -> term
    _ -> term + x * powerseries cs x
  where
    term = fromRational c

-- ----------------------------------------------------------------------
-- ** Limited domain implementations

-- $ The following are implementations of various analytic functions
-- by power series. These implementations have limited domain, and do
-- not compensate for round-off errors.

-- | Sine via its Taylor series centered at 0, evaluated as /x/ times
-- a power series in /x/². Intended for |/x/| ≤ 1.
sin_p :: (Precision e) => FixedPrec e -> FixedPrec e
sin_p x = x * powerseries coeffs (x ^ 2)
  where
    coeffs = accs step
    -- each coefficient is the previous one times −1 / (2k (2k+1))
    step prev k = negate prev * (1 % (2 * k * (2 * k + 1)))

-- | Cosine via its Taylor series centered at 0, evaluated as a power
-- series in /x/². Intended for |/x/| ≤ 1.
cos_p :: (Precision e) => FixedPrec e -> FixedPrec e
cos_p x = powerseries coeffs (x ^ 2)
  where
    coeffs = accs step
    -- each coefficient is the previous one times −1 / (2k (2k−1))
    step prev k = negate prev * (1 % (2 * k * (2 * k - 1)))

-- | Exponential function via its Taylor series centered at 0.
-- Intended for |/x/| ≤ 1.
exp_p :: (Precision e) => FixedPrec e -> FixedPrec e
exp_p = powerseries coeffs
  where
    -- each coefficient is the previous one divided by its index
    coeffs = accs (\prev k -> prev * (1 % k))

-- | Natural logarithm via its Taylor series centered at 1. Intended
-- for |/x/ − 1| ≤ 1/4: the series variable is rescaled to 4(/x/ − 1),
-- and the coefficients are divided by the matching powers of 4.
log_p :: (Precision e) => FixedPrec e -> FixedPrec e
log_p x = delta * powerseries (map coeff [0 ..]) (4 * delta)
  where
    delta = x - 1
    coeff n = 1 % ((-4) ^ n * (n + 1))

-- | Inverse tangent via its Taylor series centered at 0. Intended for
-- |/x/| ≤ 0.44: the series variable is rescaled to 5/x/².
atan_p :: (Precision e) => FixedPrec e -> FixedPrec e
atan_p x = x * powerseries (map coeff [0 ..]) (5 * x * x)
  where
    coeff n = 1 % ((-5) ^ n * (2 * n + 1))

-- | Inverse tangent via its Taylor series centered at 0, with the
-- series variable rescaled to 25/x/². Intended for |/x/| ≤ 0.2, where
-- it is faster than 'atan_p'.
atan_p2 :: (Precision e) => FixedPrec e -> FixedPrec e
atan_p2 x = x * powerseries (map coeff [0 ..]) (25 * x * x)
  where
    coeff n = 1 % ((-25) ^ n * (2 * n + 1))

-- | Inverse tangent via its Taylor series centered at 0, with the
-- series variable rescaled to 57121/x/² (57121 = 239²). Intended for
-- |/x/| ≤ 1/239, where it is faster than 'atan_p2'.
atan_p3 :: (Precision e) => FixedPrec e -> FixedPrec e
atan_p3 x = x * powerseries (map coeff [0 ..]) (57121 * x * x)
  where
    coeff n = 1 % ((-57121) ^ n * (2 * n + 1))

-- ----------------------------------------------------------------------
-- ** Raw versions of analytic functions

-- $ The following functions are \"raw\", in the sense that they do
-- not try to compensate for accumulated round-off errors. They must
-- all be wrapped in 'with_added_digits', or 'upcast' and 'downcast',
-- to produce more accurate versions.
--
-- Each function is defined on its natural domain.

-- | Raw implementation of the sine function. Arguments in [−1, 1) go
-- straight to 'sin_p'. Any other argument is reduced by the nearest
-- integer multiple /n/ of π\/2, and the result is selected from
-- ±'sin_p' \/ ±'cos_p' of the remainder according to /n/ mod 4.
sin_raw :: (Precision e) => FixedPrec e -> FixedPrec e
sin_raw x
  | -1 <= x && x < 1 = sin_p x -- bypass slow domain reduction
  | otherwise =
      case quadrant of
        0 -> sin_p reduced
        1 -> cos_p reduced
        2 -> negate (sin_p reduced)
        _ -> negate (cos_p reduced)
  where
    halfPi = pi /.. 2
    n = round (x / halfPi)
    quadrant = n `mod` 4
    reduced = x - n ..* halfPi

-- | Raw implementation of the cosine function. Arguments in [−1, 1)
-- go straight to 'cos_p'. Any other argument is reduced by the
-- nearest integer multiple /n/ of π\/2, and the result is selected
-- from ±'cos_p' \/ ±'sin_p' of the remainder according to /n/ mod 4.
cos_raw :: (Precision e) => FixedPrec e -> FixedPrec e
cos_raw x
  | -1 <= x && x < 1 = cos_p x -- bypass slow domain reduction
  | otherwise =
      case quadrant of
        0 -> cos_p reduced
        1 -> negate (sin_p reduced)
        2 -> negate (cos_p reduced)
        _ -> sin_p reduced
  where
    halfPi = pi /.. 2
    n = round (x / halfPi)
    quadrant = n `mod` 4
    reduced = x - n ..* halfPi

-- | Raw implementation of the exponential function. Arguments in
-- [−1, 1] go to 'exp_p'; any other argument is halved recursively and
-- the result squared. Note: the loss of precision is much more
-- substantial than for the other raw functions in this section,
-- because fixed-precision values get multiplied by numbers much
-- larger than 1.
exp_raw :: (Precision e) => FixedPrec e -> FixedPrec e
exp_raw x =
  if -1 <= x && x <= 1
    then exp_p x
    else exp_raw (x / 2) ^ 2

-- | Raw implementation of the natural logarithm. Calls 'error' for
-- arguments ≤ 0. Arguments in [0.75, 1.25] go to 'log_p'; all other
-- arguments are first brought closer to 1:
--
-- * above 3.5, by splitting off an integer power of /e/ with
--   'floorlog';
--
-- * between 1 and 3.5, by dividing by √/e/ (and adding 0.5);
--
-- * below 1, by taking the reciprocal (and negating the result).
log_raw :: (Precision e) => FixedPrec e -> FixedPrec e
log_raw x
  | x <= 0 = error "log: argument out of range"
  | 0.75 <= x && x <= 1.25 = log_p x
  | x > 3.5 = fromInteger expo + log rest
  | x > 1 = 0.5 + log (x / sqrtE)
  | otherwise = negate (log (1 / x))
  where
    sqrtE = exp_p 0.5
    (expo, rest) = floorlog (exp_p 1) x

-- | Raw implementation of the power function, computed as
-- exp (/y/ · log /x/). This is subject to a similar loss of precision
-- as the 'exp_raw' function.
power_raw :: (Precision e) => FixedPrec e -> FixedPrec e -> FixedPrec e
power_raw x y = exp_raw exponent
  where
    exponent = log_raw x * y

-- | Raw implementation of the 'logBase' function: the logarithm of
-- the second argument divided by the logarithm of the first (the
-- base). This is subject to a similar loss of precision as the
-- 'exp_raw' function.
logBase_raw :: (Precision e) => FixedPrec e -> FixedPrec e -> FixedPrec e
logBase_raw x y = numer / denom
  where
    numer = log y
    denom = log x

-- | Raw implementation of the square root. The underlying integer is
-- scaled up by another 10^/p/ (/p/ being the precision) so that its
-- integer square root ('intsqrt') is again a number with /p/ digits.
-- Calls 'error' for negative arguments.
sqrt_raw :: (Precision e) => FixedPrec e -> FixedPrec e
sqrt_raw a@(F x)
  | x < 0 = error "sqrt: argument out of range"
  | otherwise = F (intsqrt (decshiftL (getprec a) x))

-- | Raw implementation of the inverse tangent. Arguments with
-- |/x/| ≤ 0.44 go to 'atan_p'. Negative arguments are handled by
-- symmetry; arguments ≥ 2.27 via π\/2 − atan (1\//x/); the remaining
-- ones via π\/4 + atan ((/x/−1)\/(/x/+1)).
atan_raw :: (Precision e) => FixedPrec e -> FixedPrec e
atan_raw x
  | abs x <= 0.44 = atan_p x
  | x < 0 = negate (atan (negate x))
  | x >= 2.27 = halfPi - atan_p (1 / x)
  | otherwise = quarterPi + atan_p ((x - 1) / (x + 1))
  where
    halfPi = pi /.. 2
    quarterPi = pi /.. 4

-- | Raw implementation of π, computed as
-- 16 atan (1\/5) − 4 atan (1\/239).
pi_raw :: (Precision e) => FixedPrec e
pi_raw = 16 ..* atanFifth - 4 ..* atanSmall
  where
    atanFifth = atan_p2 (1 / 5)
    atanSmall = atan_p3 (1 / 239)

-- | Raw implementation of the inverse sine function, expressed via
-- 'atan' and the complementary value √(1 − /x/²). For |/x/| ≤ 0.7 the
-- ratio /x/ \/ √(1 − /x/²) is used; closer to ±1 the reciprocal ratio
-- is used and the result is measured from ±π\/2. Calls 'error'
-- outside of [−1, 1].
asin_raw :: (Precision e) => FixedPrec e -> FixedPrec e
asin_raw x
  | -0.7 <= x && x <= 0.7 = atan (x / co)
  | x > 0 && x <= 1 = halfPi - atan (co / x)
  | x < 0 && x >= -1 = negate halfPi - atan (co / x)
  | otherwise = error "asin: argument out of range"
  where
    co = sqrt (1 - x ^ 2)
    halfPi = pi /.. 2

-- | Raw implementation of the inverse cosine function, expressed via
-- 'atan' and the complementary value √(1 − /x/²). For |/x/| ≤ 0.7 the
-- result is π\/2 minus the inverse sine; closer to ±1 the reciprocal
-- ratio is used, with π added for negative arguments. Calls 'error'
-- outside of [−1, 1].
acos_raw :: (Precision e) => FixedPrec e -> FixedPrec e
acos_raw x
  | -0.7 <= x && x <= 0.7 = halfPi - atan (x / si)
  | x > 0 && x <= 1 = atan (si / x)
  | x < 0 && x >= -1 = pi + atan (si / x)
  | otherwise = error "acos: argument out of range"
  where
    si = sqrt (1 - x ^ 2)
    halfPi = pi /.. 2

-- | Raw implementation of the hyperbolic sine, computed as
-- (exp /x/ − 1 \/ exp /x/) \/ 2.
sinh_raw :: (Precision e) => FixedPrec e -> FixedPrec e
sinh_raw x =
  let ex = exp x
   in (ex - 1 / ex) /.. 2

-- | Raw implementation of the hyperbolic cosine, computed as
-- (exp /x/ + 1 \/ exp /x/) \/ 2.
cosh_raw :: (Precision e) => FixedPrec e -> FixedPrec e
cosh_raw x =
  let ex = exp x
   in (ex + 1 / ex) /.. 2

-- | Raw implementation of the inverse hyperbolic tangent, computed as
-- log ((1 + /x/) \/ (1 − /x/)) \/ 2.
atanh_raw :: (Precision e) => FixedPrec e -> FixedPrec e
atanh_raw x = log ratio /.. 2
  where
    ratio = (1 + x) / (1 - x)

-- | Raw implementation of the inverse hyperbolic sine, computed as
-- log (/x/ + √(/x/² + 1)).
asinh_raw :: (Precision e) => FixedPrec e -> FixedPrec e
asinh_raw x = log (x + root)
  where
    root = sqrt (x ^ 2 + 1)

-- | Raw implementation of the inverse hyperbolic cosine, computed as
-- log (/x/ + √(/x/² − 1)). Calls 'error' for arguments below 1.
acosh_raw :: (Precision e) => FixedPrec e -> FixedPrec e
acosh_raw x =
  if x >= 1
    then log (x + sqrt (x ^ 2 - 1))
    else error "acosh: argument out of range"

-- ----------------------------------------------------------------------
-- Instance declarations

-- | Plain decimal notation: an optional minus sign, the integral
-- part, a decimal point, and the fractional part left-padded with
-- zeros to the precision of the type.
instance (Precision e) => Show (FixedPrec e) where
  show a@(F x) = sign ++ integral ++ "." ++ fraction
    where
      p = getprec a
      -- sign and magnitude are rendered separately
      x' = abs x
      sign = if x < 0 then "-" else ""
      integral = show (dectruncR p x')
      -- digits after the decimal point, before zero-padding
      fraction' = show $ x' `mod` decshiftL p 1
      fraction = pad_to_length p '0' fraction'
      -- left-pad a string with the given character up to the given length
      pad_to_length len c l = replicate (len - length l) c ++ l

-- | Arithmetic on the underlying scaled integers. Sums, differences,
-- negation and absolute value act on the integers directly; a product
-- has twice the scale and is shifted back with rounding
-- ('decshiftR'); integers are converted by appending zeros.
instance (Precision e) => Num (FixedPrec e) where
  (+) (F x) (F y) = F (x + y)
  (-) (F x) (F y) = F (x - y)
  (*) a@(F x) (F y) = F (decshiftR (getprec a) (x * y))
  negate (F x) = F (negate x)
  abs (F x) = F (abs x)
  signum (F x) = fromInteger (signum x)
  fromInteger k = result
    where
      -- the precision is read off the type of the result itself
      result = F (decshiftL (getprec result) k)

-- | Division scales the dividend by 10^/p/ (/p/ being the precision)
-- before dividing the underlying integers with rounding ('divi').
-- Rationals are converted by dividing numerator by denominator.
instance (Precision e) => Fractional (FixedPrec e) where
  (/) a@(F x) (F y) = F (decshiftL (getprec a) x `divi` y)
  fromRational r = fromInteger (numerator r) / fromInteger (denominator r)

-- | The exact rational value: the underlying integer over 10^/p/,
-- where /p/ is the precision.
instance (Precision e) => Real (FixedPrec e) where
  toRational a@(F x) = x % decshiftL (getprec a) 1

-- | Split into integral and fractional part by truncating towards
-- zero ('quotRem' on the underlying integer), so the fractional part
-- carries the sign of the number.
instance (Precision e) => RealFrac (FixedPrec e) where
  properFraction a@(F x) = (fromInteger whole, F part)
    where
      (whole, part) = x `quotRem` decshiftL (getprec a) 1

-- | Analytic functions. Most of them wrap the corresponding raw
-- implementation between 'upcast' and 'downcast', i.e. they compute
-- with three guard digits. 'exp', '**' and 'logBase' can lose more
-- digits than that, so they choose the number of additional digits
-- depending on their arguments (see 'with_added_digits').
instance (Precision e) => Floating (FixedPrec e) where
  pi = downcast pi_raw
  sin = downcast . sin_raw . upcast
  cos = downcast . cos_raw . upcast
  log = downcast . log_raw . upcast
  sqrt = downcast . sqrt_raw . upcast
  atan = downcast . atan_raw . upcast
  asin = downcast . asin_raw . upcast
  acos = downcast . acos_raw . upcast
  sinh = downcast . sinh_raw . upcast
  cosh = downcast . cosh_raw . upcast
  atanh = downcast . atanh_raw . upcast
  asinh = downcast . asinh_raw . upcast
  acosh = downcast . acosh_raw . upcast

  exp x
    | x <= 1 = exp_raw x
    | otherwise = with_added_digits d (cast . exp_raw) x
    where
      -- we need to add digits to the internal calculation, because
      -- exp_raw multiplies numbers much larger than 1.
      d = 1 + ceiling (x * 0.45)

  x ** y
    | x <= 1 = power_raw x y
    | otherwise = with_added_digits d (cast . (power_raw (cast x))) y
    where
      -- we don't need a lot of precision in the logarithm here,
      -- because it is only to determine the number of digits
      d = 1 + ceiling (0.45 * y * cast (log_raw (cast x :: FixedPrec P10)))

  logBase x y
    -- Fast path: the base is bounded away from 1 and the argument is
    -- of moderate size (strictly between lo and hi), so three guard
    -- digits are used.
    | (x < 0.36 || x > 2.72) && lo < y && y < hi = downcast (logBase_raw (upcast x) (upcast y))
    | otherwise = with_added_digits d (cast . (logBase_raw (cast x))) y
    where
      -- number of additional digits, estimated from the sizes of
      -- log x and log y
      dx = ceiling (-0.45 * log (abs (log_double x)))
      dy = ceiling (0.45 * log (abs (log_double y)))
      d = max dx (2*dx + dy)
      -- bounds of the fast path; lo must be the smaller one, otherwise
      -- the range (lo, hi) is empty and the fast path is never taken
      lo = 0.0000000001
      hi = 10000000000

-- | Random fixed-precision numbers. Both bounds are scaled by 10^/n/
-- (/n/ being the precision in decimal digits) and floored to
-- integers, an integer is drawn from that range, and the draw is
-- scaled back down. 'random' draws from the range (0, 1).
instance Precision e => Random (FixedPrec e) where
  randomR (lo, hi) g = (result, g')
    where
      n = getprec result -- precision in decimal digits
      loScaled = floor (lo * 10 ^ n)
      hiScaled = floor (hi * 10 ^ n)
      (drawn, g') = randomR (loScaled, hiScaled) g
      result = 0.1 ^ n * fromInteger drawn

  random = randomR (0, 1)
```