{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PostfixOperators #-}

-----------------------------------------------------------------------------
-- | This module exports a collection of utilities for working with the
-- internals of the CReal datatype. Take care to maintain the CReal invariant
-- when using these functions.
----------------------------------------------------------------------------
module Data.CReal.Internal
  ( CReal(..)
  , atPrecision
  , crealPrecision

  , expBounded
  , logBounded

  , atanBounded
  , sinBounded
  , cosBounded

  , shiftL
  , shiftR

  , powerSeries
  , alternateSign

  , (/.)
  , log2
  , log10
  , isqrt

  , showAtPrecision
  , decimalDigitsAtPrecision
  , rationalToDecimal
  ) where

import Data.List (scanl')
import Data.Ratio (numerator,denominator,(%))
import GHC.Base (Int(..))
import GHC.Integer.Logarithms (integerLog2#, integerLogBase#)
import GHC.TypeLits

-- $setup
-- >>> :set -XDataKinds

infixl 7 /.

default ()

-- | The type CReal represents a fast binary Cauchy sequence. This is
-- a Cauchy sequence with the invariant that the pth element will be within
-- 2^-p of the true value. Internally this sequence is represented as
-- a function from Ints to Integers.
--
-- The wrapped function maps a precision @p@ to the numerator of the
-- approximation; the denominator is implicitly 2^p (see 'atPrecision').
--
-- The type-level parameter @n@ is the default precision: it is what the
-- 'Show', 'Eq', 'Ord' and 'Real' instances in this module pass as @p@ (via
-- 'crealPrecision') when they have to pick a concrete approximation.
newtype CReal (n :: Nat) = CR (Int -> Integer)

-- | Reflect the type-level default precision of a 'CReal' as an 'Int'. The
-- value of the argument is never inspected; only its type is used.
--
-- >>> crealPrecision (1 :: CReal 10)
-- 10
crealPrecision :: KnownNat n => CReal n -> Int
crealPrecision x = fromInteger (natVal x)

-- | @x \`atPrecision\` p@ queries the Cauchy sequence represented by @x@ at
-- index @p@ and returns the numerator of that approximation. The
-- corresponding denominator is 2^p.
--
-- >>> 10 `atPrecision` 10
-- 10240
atPrecision :: CReal n -> Int -> Integer
atPrecision (CR approximate) p = approximate p

-- | Showing a @CReal p@ renders a decimal number @d@ that lies within 2^-p of
-- the true value, where @p@ is the precision carried in the type.
--
-- >>> show (47176870 :: CReal 0)
-- "47176870"
instance KnownNat n => Show (CReal n) where
  show r = showAtPrecision bits r
    where bits = crealPrecision r -- default precision taken from the type

-- | @signum (x :: CReal p)@ returns the sign of @x@ at precision @p@. It's
-- important to remember that this /may not/ represent the actual sign of @x@ if
-- the distance between @x@ and zero is less than 2^-@p@.
--
-- This is a little bit of a fudge, but it's probably better than failing to
-- terminate when trying to find the sign of zero. The class still respects the
-- abs-signum law though.
--
-- >>> signum (0.1 :: CReal 2)
-- 0.0
--
-- >>> signum (0.1 :: CReal 3)
-- 1.0
instance Num (CReal n) where
  -- An integer i is exactly i * 2^p / 2^p at every precision p.
  fromInteger i = CR (\p -> i * 2 ^ p)

  negate (CR x) = CR (negate . x)

  abs (CR x) = CR (abs . x)

  {-# INLINE (+) #-}
  -- Both operands are queried with two extra bits of precision, and the sum
  -- is then scaled back down by 2^2 with rounding.
  CR x1 + CR x2 = CR (\p -> let n1 = x1 (p + 2)
                                n2 = x2 (p + 2)
                            in (n1 + n2) /. 4)

  {-# INLINE (*) #-}
  -- s1 and s2 are derived solely from the operands' approximations at
  -- precision 0; they do not depend on the requested precision. They are
  -- therefore bound outside the lambda so that they are computed at most once
  -- per product instead of once per precision query.
  CR x1 * CR x2 =
    let s1 = log2 (abs (x1 0) + 2) + 3
        s2 = log2 (abs (x2 0) + 2) + 3
    in CR (\p -> let -- each operand is queried with extra precision
                     -- determined by the size of the /other/ operand
                     n1 = x1 (p + s2)
                     n2 = x2 (p + s1)
                 in (n1 * n2) /. 2^(p + s1 + s2))

  -- The sign is read off the approximation at whatever precision the result
  -- is queried at, then scaled so that it represents -1, 0 or 1.
  signum x = CR (\p -> signum (x `atPrecision` p) * 2^p)

-- | Taking the reciprocal of zero will not terminate
instance Fractional (CReal n) where
  -- This should be in base
  fromRational n = fromInteger (numerator n) / fromInteger (denominator n)

  {-# INLINE recip #-}
  -- TODO: Make recip 0 throw an error (if, for example, it would take more
  -- than 4GB of memory to represent the result)
  --
  -- s is found by searching for a precision at which the approximation of the
  -- argument has magnitude at least 3. That search is what fails to terminate
  -- for zero, since no approximation of zero ever gets that large. s depends
  -- only on the argument and not on the requested precision, so it is bound
  -- outside the lambda: the search then runs at most once per reciprocal
  -- rather than once per precision query.
  recip (CR x) =
    let s = findFirstMonotonic ((3 <=) . abs . x)
    in CR (\p -> let n = x (p + 2 * s + 2)
                 in 2^(2 * p + 2 * s + 2) /. n)

-- | Most methods work by reducing their argument into the input range of one
-- of the bounded helpers ('expBounded', 'logBounded', 'sinBounded',
-- 'cosBounded', 'atanBounded') and then correcting the result.
instance Floating (CReal n) where
  -- TODO: Could we use something faster such as Ramanujan's formula
  pi = 4 * piBy4

  -- Range reduction: with k the nearest integer to x / ln 2 we have
  -- exp x = 2^k * exp (x - k * ln 2).
  exp x
    | k == 0    = expBounded x
    | otherwise = expBounded remainder `shiftL` fromInteger k
    where CR quotient = x / ln2
          k           = quotient 0
          remainder   = x - fromInteger k * ln2

  -- Range reduction on the principle that ln (a * b) = ln a + ln b: a power
  -- of two is shifted out of the argument and added back as a multiple of
  -- ln 2. Arguments for which the computed exponent is negative are handled
  -- through the reciprocal.
  log x
    | k < 0     = negate (log (recip x))
    | k == 0    = logBounded x
    | otherwise = logBounded scaled + fromIntegral k * ln2
    where CR approx = x
          k         = log2 (approx 2) - 2
          scaled    = x `shiftR` k

  -- The approximation at precision 2p has denominator 2^(2p), so its integer
  -- square root is a numerator over 2^p.
  sqrt (CR approx) = CR (isqrt . approx . (2 *))

  -- This will diverge when the base is not positive
  base ** ex = exp (log base * ex)

  logBase base v = log v / log base

  sin x = cos (x - pi / 2)

  -- k is the nearest integer to x / (pi/4); the argument is split into
  -- k * pi/4 plus a small offset and the octant (k mod 8) selects which
  -- bounded function to apply to that offset.
  cos x = reduce offset
    where CR quotient = x / piBy4
          k           = quotient 1 /. 2
          offset      = x - (fromIntegral k * piBy4)
          reduce      = case k `mod` 8 of
            0 ->          cosBounded
            1 -> negate . sinBounded . subtract piBy4
            2 -> negate . sinBounded
            3 -> negate . cosBounded . (piBy4-)
            4 -> negate . cosBounded
            5 ->          sinBounded . subtract piBy4
            6 ->          sinBounded
            _ ->          cosBounded . (piBy4-) -- only 7 remains

  -- TODO: use multiplyBounded here
  tan x = sin x / cos x

  asin x = 2 * atan (x / (1 + sqrt (1 - x*x)))

  acos x = pi/2 - asin x

  -- The argument is classified by its value to the nearest 1/4 and mapped
  -- into the input range of atanBounded accordingly.
  atan x
    | q <  -4   = atanBounded (negate (recip x)) - pi / 2
    | q == -4   = negate (pi / 4) - atanBounded ((x + 1) / (x - 1))
    | q ==  4   = pi / 4 + atanBounded ((x - 1) / (x + 1))
    | q >   4   = pi / 2 - atanBounded (recip x)
    | otherwise = atanBounded x
    where q = x `atPrecision` 2

  -- TODO: benchmark replacing these with their series expansion
  sinh x = (exp x - exp (negate x)) / 2
  cosh x = (exp x + exp (negate x)) / 2
  tanh x = (e2x - 1) / (e2x + 1)
    where e2x = exp (2 * x)

  asinh x = log (x + sqrt (x * x + 1))
  acosh x = log (x + sqrt (x + 1) * sqrt (x - 1))
  atanh x = (log (1 + x) - log (1 - x)) / 2

-- | 'toRational' evaluates the number at the precision @n@ carried in its
-- type and returns that approximation (numerator over 2^n) as a 'Rational'.
instance KnownNat n => Real (CReal n) where
  toRational x = (x `atPrecision` p) % (2 ^ p)
    where p = crealPrecision x

-- | Equality on @CReal p@ is decided at precision @p@: two values are equal
-- when their difference, evaluated at that precision, is zero. Values which
-- differ by less than 2^-p may therefore compare as equal.
--
-- >>> 0 == (0.1 :: CReal 1)
-- True
instance KnownNat n => Eq (CReal n) where
  -- TODO, should this try smaller values first?
  a == b = (a - b) `atPrecision` crealPrecision a == 0

-- | Like equality, ordering of values of type @CReal p@ is decided at
-- precision @p@, by comparing their difference at that precision with zero.
--
-- 'max' and 'min' do not fix a precision: they select the larger (or
-- smaller) of the two approximations at each precision they are queried at.
instance KnownNat n => Ord (CReal n) where
  compare a b = ((a - b) `atPrecision` crealPrecision a) `compare` 0
  max (CR f) (CR g) = CR (\p -> f p `max` g p)
  min (CR f) (CR g) = CR (\p -> f p `min` g p)

--------------------------------------------------------------------------------
-- Some utility functions
--------------------------------------------------------------------------------

--
-- Constants
--

-- | A quarter of pi, computed with the Machin formula
-- @pi/4 = 4 * atan (1/5) - atan (1/239)@.
piBy4 :: CReal n
piBy4 = 4 * atanFifth - atan239th
  where atanFifth = atanBounded (1/5)
        atan239th = atanBounded (1 / 239)

-- | The natural logarithm of 2, computed with 'logBounded' (2 lies within
-- the input range [1..2] documented for that function).
ln2 :: CReal n
ln2 = logBounded 2

--
-- Bounded exponential functions
--

-- | The exponential function, evaluated as the power series with
-- coefficients @1 / k!@.
--
-- The input to expBounded must be in the range (-1..1)
expBounded :: CReal n -> CReal n
expBounded x = powerSeries coefficients (max 5) x
  where coefficients = [1 % (k!) | k <- [0..]]

-- | The natural logarithm, evaluated as a power series in @y = (x - 1) / x@
-- with coefficients @1 / k@ (one factor of @y@ is multiplied in separately).
--
-- The input must be in [1..2]
logBounded :: CReal n -> CReal n
logBounded x = y * powerSeries coefficients (*2) y
  where coefficients = [1 % k | k <- [1..]]
        y            = (x - 1) / x

--
-- Bounded trigonometric functions
--

-- | The sine function, evaluated as @x@ times a power series in @x*x@ whose
-- coefficients are the alternating running products of @1 / (k * (k + 1))@
-- for even @k@ starting at 2.
--
-- The input to sinBounded must be in (-1..1)
sinBounded :: CReal n -> CReal n
sinBounded x = x * powerSeries coefficients (max 1) (x*x)
  where coefficients = alternateSign magnitudes
        magnitudes   = scanl' (*) 1 [1 % (k*(k+1)) | k <- [2,4..]]

-- | The cosine function, evaluated as a power series in @x*x@ whose
-- coefficients are the alternating running products of @1 / (k * (k + 1))@
-- for odd @k@ starting at 1.
--
-- The input to cosBounded must be in (-1..1)
cosBounded :: CReal n -> CReal n
cosBounded x = powerSeries coefficients (max 1) (x*x)
  where coefficients = alternateSign magnitudes
        magnitudes   = scanl' (*) 1 [1 % (k*(k+1)) | k <- [1,3..]]

-- | The arctangent function, evaluated as @x / d@ times a power series in
-- @x*x / d@, where @d = 1 + x*x@ and the coefficients are the running
-- products of @k / (k + 1)@ for even @k@ starting at 2.
--
-- The input to atanBounded must be in [-1..1]
atanBounded :: CReal n -> CReal n
atanBounded x = CR approximate
  where
    coefficients = scanl' (*) 1 [k % (k + 1) | k <- [2,4..]]
    denom        = 1 + x * x
    -- The product is built inside the per-precision function, as in a
    -- direct lambda, rather than being bound once outside it.
    approximate p =
      let series = powerSeries coefficients (+1) (x*x/denom)
      in ((x/denom) * series) `atPrecision` p

--
-- Multiplication with powers of two
--

-- | @x \`shiftR\` n@ is equal to @x@ divided by 2^@n@
--
-- @n@ can be negative or zero
--
-- This can be faster than doing the division, because the shift is applied
-- to the precision at which @x@ is queried. If the shifted precision would
-- be negative, @x@ is queried at precision 0 and the result is divided (with
-- rounding) by the remaining power of two.
shiftR :: CReal n -> Int -> CReal n
shiftR (CR f) n = CR shifted
  where
    shifted p
      | q >= 0    = f q
      | otherwise = f 0 /. 2 ^ negate q
      where q = p - n

-- | @x \`shiftL\` n@ is equal to @x@ multiplied by 2^@n@
--
-- @n@ can be negative or zero
--
-- This can be faster than doing the multiplication. It is implemented as a
-- 'shiftR' by the negated amount.
shiftL :: CReal n -> Int -> CReal n
shiftL x n = shiftR x (negate n)


--
-- Showing CReals
--

-- | @showAtPrecision p x@ renders the approximation of @x@ at precision @p@
-- (its numerator over 2^p) as a decimal string, using as many decimal places
-- as 'decimalDigitsAtPrecision' reports for @p@.
showAtPrecision :: Int -> CReal n -> String
showAtPrecision p (CR f) = rationalToDecimal places approximation
  where places        = decimalDigitsAtPrecision p
        approximation = f p % 2^p

-- | The number of decimal digits after the point used to represent a number
-- to within 2^-p. Precision 0 needs no fractional digits; otherwise the
-- answer is one more than the base 10 logarithm of 2^p, rounded down.
decimalDigitsAtPrecision :: Int -> Int
decimalDigitsAtPrecision p
  | p == 0    = 0
  | otherwise = log10 (2^p) + 1

-- | @rationalToDecimal p x@ returns a string representing @x@ at @p@ decimal
-- places.
--
-- The magnitude of @x@ is scaled by 10^@p@ and rounded to an integer with
-- '/.'; the digits of that integer are then split into an integer part and
-- exactly @p@ fractional digits. A leading @-@ is emitted whenever @x@ itself
-- is negative, and the decimal point is omitted when @p@ is not positive.
rationalToDecimal :: Int -> Rational -> String
rationalToDecimal places r = sign ++ whole ++ fraction
  where
    sign | signum r == -1 = "-"
         | otherwise      = ""
    magnitude = abs r
    digits    = show ((numerator magnitude * 10^places) /. denominator magnitude)
    count     = length digits
    -- With too few digits to fill the fractional part, pad it with leading
    -- zeros and use "0" as the integer part.
    (whole, fractional)
      | count <= places = ("0", replicate (places - count) '0' ++ digits)
      | otherwise       = splitAt (count - places) digits
    fraction | places > 0 = "." ++ fractional
             | otherwise  = ""


--
-- Integer operations
--

-- | Integer division that rounds to the nearest integer, with exact halves
-- rounded to the nearest even integer. The quotient is formed as an exact
-- ratio and then rounded with 'round'.
(/.) :: Integer -> Integer -> Integer
numer /. denom = round (numer % denom)

-- | @log2 x@ returns the base 2 logarithm of @x@ rounded towards zero. This
-- is a boxed wrapper around the primitive 'integerLog2#'.
log2 :: Integer -> Int
log2 n = I# (integerLog2# n)

-- | @log10 x@ returns the base 10 logarithm of @x@ rounded towards zero. This
-- is a boxed wrapper around the primitive 'integerLogBase#' with base 10.
log10 :: Integer -> Int
log10 n = I# (integerLogBase# 10 n)

-- | @isqrt x@ returns the square root of @x@ rounded towards zero.
--
-- Calls 'error' for negative input. For positive input the result is found
-- by repeatedly averaging the current guess @r@ with @x \`div\` r@, starting
-- from a power of two with about half as many bits as @x@, until
-- @r*r <= x < (r+1)*(r+1)@.
isqrt :: Integer -> Integer
isqrt x
  | x < 0     = error "Sqrt applied to negative Integer"
  | x == 0    = 0
  | otherwise = refine (2 ^ (log2 x `div` 2))
  where
    refine r
      | r * r <= x && (r + 1) * (r + 1) > x = r
      | otherwise = refine ((r + (x `div` r)) `div` 2)

-- | Factorial function: the product of the integers from 2 up to the
-- argument. Arguments below 2 give the empty product, 1.
(!) :: Integer -> Integer
(!) = product . enumFromTo 2

--
-- Searching
--

-- | Given a monotonic predicate @p@ (one which is 'False' up to some index
-- and 'True' from then on), locate the point at which it becomes 'True'.
--
-- An upper bound is first found by testing 1, 2, 4, 8, ... until @p@ holds,
-- and the bracket between the last two tested indices is then narrowed by
-- binary search. The value returned is the /lower/ end of the final
-- unit-width bracket, i.e. one less than the smallest index @>= 1@ at which
-- @p@ holds. Index 0 itself is never tested.
--
-- The search keeps doubling for as long as @p@ fails, so it does not
-- terminate normally if @p@ never holds.
findFirstMonotonic :: (Int -> Bool) -> Int
findFirstMonotonic p = uncurry narrow (bracket 0 1)
  where
    -- Double the upper bound until the predicate holds there.
    bracket lo hi
      | p hi      = (lo, hi)
      | otherwise = bracket hi (hi * 2)
    -- Halve the bracket, keeping the predicate true at its upper end.
    narrow lo hi
      | lo + 1 == hi = lo
      | p mid        = narrow lo mid
      | otherwise    = narrow mid hi
      where mid = lo + ((hi - lo) `div` 2)


--
-- Power series
--

-- | Negate every other element of a list, leaving the first element as it
-- is and negating the second, fourth, and so on.
--
-- >>> alternateSign [1..5]
-- [1,-2,3,-4,5]
alternateSign :: Num a => [a] -> [a]
alternateSign xs = [f x | (f, x) <- zip (cycle [id, negate]) xs]

-- | @powerSeries q f x `atPrecision` p@ evaluates the power series with
-- coefficients @q@ at @x@, summing the first @f p + 1@ terms.
--
-- @f@ should be a function such that the CReal invariant is maintained, i.e.
-- it must ask for enough terms at each precision.
--
-- Internally the powers of @x@ are computed in fixed point with extra bits
-- of working precision derived from the number of terms, and the sum is
-- scaled back down (with rounding) at the end.
--
-- See any of the trig functions for an example
powerSeries :: [Rational] -> (Int -> Int) -> CReal n -> CReal n
powerSeries q termsAtPrecision (CR x) = CR evaluate
  where
    evaluate p = total /. 4^extra
      where
        terms = termsAtPrecision p
        -- extra working bits, derived from the number of terms
        extra = log2 (toInteger terms) + 2
        pPow  = p + extra    -- precision at which the powers of x are kept
        pArg  = pPow + extra -- precision at which x itself is queried
        m     = x pArg
        -- successive powers of x as numerators over 2^pPow, starting at x^0
        step e = m * e /. 2^pArg
        powers = [e % 1 | e <- iterate step (2^pPow)]
        -- each term is scaled up by 2^extra before being rounded
        scaled = [round (c * xk * 2^extra) | (c, xk) <- zip q powers]
        total  = sum (take (terms + 1) scaled)