{-# INCLUDE <mpfr.h> #-}
{-# INCLUDE <hsmpfr.h> #-}

module Data.Number.Ball (Ball(..), makeA, make, 
             normalizeBall,
             lower, upper, lower_, upper_,
             sgnLower, sgnUpper,
             width, compareB,
             below, contains, 
             intersectA, intersect,
             add, sub, neg, absB, mul, div, sqrt, exp, log,
             maxB, minB,
             fromDyadic, fromString, fromInt, fromWord )
where

import qualified Data.Number.Dyadic as D
import Data.Order

import Prelude hiding (div, sqrt, exp, log)
import Data.Word(Word)

-- | A 'Ball' is the closed interval @[center - radius, center + radius]@.
--
-- Both fields are strict.  The radius is expected to be non-negative;
-- 'correctDigits' reports an error for a negative radius.
data Ball = Ball
  { center :: !D.Dyadic -- ^ midpoint of the interval
  , radius :: !D.Dyadic -- ^ half of the interval's width
  }
{-
instance Show Ball where
    show b@(Ball c r) = "\ncenter = " ++ D.toString dc c ++ "\n" ++ "radius = " ++ 
                        D.toString dr r
                            where dc = min 60 $ (decimalPrec . correctDigits) b
                                  dr = D.getPrec r
  -}
-- | Renders the center with 'D.toString', using a decimal digit count derived
-- from the ball's accuracy, followed by that digit count in brackets.
instance Show Ball where
    show b@(Ball c r) = D.toString digits c ++ "[" ++ show digits ++ "]"
      where
        -- decimal digits justified by the number of correct bits
        accurate = decimalPrec (correctDigits b)
        radExp   = D.getExp r
        -- With no accurate digits but a negative radius exponent, derive the
        -- digit count from the radius exponent instead.
        digits
          | accurate == 0 && radExp < 0 =
              decimalPrec (fromIntegral (negate (succ radExp)))
          | otherwise = accurate

-- | Precision used when computing ball radii.  Radii are always rounded
-- upward, so a small fixed precision is enough for them.
radPrec :: D.Precision
radPrec = 32

-- | Epsilon neighbourhood of a dyadic, scaled by the number of accurate
-- digits it carries: returns @m * 2 ^ (e - p - 1)@, where @e@ is the exponent
-- and @p@ the precision of the dyadic.
--
-- For a zero dyadic the exponent argument passed to 'D.int2i' is 0.
createEpsilon :: Int      -- ^ multiplier @m@
              -> D.Dyadic -- ^ dyadic with exponent @e@ and precision @p@
              -> D.Dyadic
createEpsilon m d = D.int2i D.Zero radPrec m scale
  where
    scale
      | d == 0    = 0
      | otherwise = D.getExp d - fromIntegral (D.getPrec d) - 1

-- | If the first argument is non-zero, widen the second argument upward by
-- @createEpsilon 1@ of the third; otherwise return the second unchanged.
-- Callers pass the inexactness flag returned by the rounding @_@ functions.
addEpsilon :: Int      -- ^ flag; non-zero means a correction is needed
           -> D.Dyadic -- ^ dyadic to be corrected (a radius at every call site)
           -> D.Dyadic -- ^ dyadic whose magnitude sets the size of the correction
           -> D.Dyadic
addEpsilon flag d x
  | flag == 0 = d
  | otherwise = D.add D.Up (D.getPrec d) d (createEpsilon 1 x)

-- | Make a ball from its endpoints, rounding the center to precision @p@.
--
-- The center is @(d1 + d2) / 2@ rounded to nearest.  The radius is
-- @(d2 - d1) / 2@ rounded up at 'radPrec', and it is widened via 'addEpsilon'
-- when the sum @d1 + d2@ was inexact at precision @p@.
makeA :: D.Precision -- ^ desired precision of the center
      -> D.Dyadic    -- ^ left endpoint
      -> D.Dyadic    -- ^ right endpoint
      -> Ball
makeA p d1 d2 =
  let (total, inexact) = D.add_ D.Near p d1 d2
      halfWidth        = D.div2w D.Up radPrec (D.sub D.Up radPrec d2 d1) 1
  in Ball (D.div2w D.Near p total 1) (addEpsilon inexact halfWidth total)


-- | Make a ball from endpoints, choosing the center's precision with
-- 'D.addPrec' so that no precision is lost.
make :: D.Dyadic -- ^ left endpoint
     -> D.Dyadic -- ^ right endpoint
     -> Ball
make lo hi = makeA (D.addPrec lo hi) lo hi
 
-- | Round a ball's center to precision @p@ (to nearest) and its radius to
-- 'radPrec' (upward).  If rounding the center was inexact, the radius is
-- widened, so the resulting ball may be larger than the input.
normalizeBall :: D.Precision -> Ball -> Ball
normalizeBall p (Ball c r) =
  let (c', inexact) = D.set_ D.Near p c
  in Ball c' (addEpsilon inexact (D.set D.Up radPrec r) c')

-- | Make a ball from a dyadic, rounding it to precision @p@ (to nearest).
-- The radius starts at 0 and is widened only if that rounding was inexact,
-- so it stays 0 when @p@ is not smaller than the dyadic's precision.
fromDyadic :: D.Precision -> D.Dyadic -> Ball
fromDyadic p d =
  let (c, inexact) = D.set_ D.Near p d
      zeroRadius   = D.fromWord D.Up radPrec 0
  in Ball c (addEpsilon inexact zeroRadius c)

-- | Make a ball from an 'Int', rounding it to precision @p@ (to nearest).
-- As with 'fromDyadic', the radius starts at 0 and is widened only if the
-- rounding was inexact.
fromInt :: D.Precision -> Int -> Ball
fromInt p n =
  let (c, inexact) = D.fromInt_ D.Near p n
      zeroRadius   = D.fromWord D.Up radPrec 0
  in Ball c (addEpsilon inexact zeroRadius c)

-- | Make a ball from a machine word, rounding the center to precision @p@
-- (to nearest).  The radius is 0 unless that rounding was inexact.
--
-- The word is first converted to a 64-bit-precision dyadic, which represents
-- every 'Word' (at most 64 bits wide on GHC platforms) exactly, and is then
-- rounded by 'fromDyadic'.  Converting through 'Int' instead would wrap words
-- above @maxBound :: Int@ to negative numbers.
fromWord   :: D.Precision -> Word -> Ball
fromWord p = fromDyadic p . D.fromWord D.Near 64

-- | Lower endpoint @center - radius@, rounded down to precision @p@.
lower :: D.Precision -> Ball -> D.Dyadic
lower prec (Ball mid rad) = D.sub D.Down prec mid rad

-- | Upper endpoint @center + radius@, rounded up to precision @p@.
upper :: D.Precision -> Ball -> D.Dyadic
upper prec (Ball mid rad) = D.add D.Up prec mid rad

-- | Lower endpoint, rounded down to the precision of the ball's own center.
lower_ :: Ball -> D.Dyadic
lower_ b = lower (D.getPrec (center b)) b

-- | Upper endpoint, rounded up to the precision of the ball's own center.
upper_ :: Ball -> D.Dyadic
upper_ b = upper (D.getPrec (center b)) b

-- | Sign of the ball's lower endpoint (-1, 0 or 1).  It compares @center@
-- with @radius@ instead of computing @center - radius@, which should be
-- cheaper than @signum (center b - radius b)@.
sgnLower :: Ball -> Int
sgnLower (Ball c r) = case c `compare` r of
  LT -> -1
  EQ -> 0
  GT -> 1

-- | Sign of the ball's upper endpoint (-1, 0 or 1).  It compares @-radius@
-- with @center@ instead of computing @center + radius@.
sgnUpper :: Ball -> Int
sgnUpper (Ball c r) =
  let negR = D.neg D.Near (D.getPrec r) r
  in case compare negR c of
       LT -> 1
       EQ -> 0
       GT -> -1

-- | Upper bound on the width of the ball: @2 * radius@, rounded up at
-- 'radPrec'.
width :: Ball -> D.Dyadic
width b = D.mul2w D.Up radPrec (radius b) 1

-- | @below a b@ holds when ball @b@ lies inside ball @a@.  The endpoints are
-- taken at the precision of each ball's center ('lower_', 'upper_').
below :: Ball -> Ball -> Bool
below outer inner =
  lower_ outer <= lower_ inner && upper_ outer >= upper_ inner

-- | Whether the dyadic lies between the ball's endpoints, which are taken at
-- the precision of the center.
contains :: Ball -> D.Dyadic -> Bool
contains ball x = lower_ ball <= x && upper_ ball >= x

-- | Intersection of two balls, with the result's center at precision @p@.
-- The endpoints are rounded outward.  If the balls are disjoint, the
-- computation fails with 'fail'.
intersectA :: Monad m
           => D.Precision -- ^ precision of the resulting ball's center
           -> Ball -> Ball -> m Ball
intersectA p b1 b2
  | lo <= hi  = return (makeA p lo hi)
  | otherwise = fail "cannot intersect disjoint intervals"
  where
    lo = D.maxD D.Down p (lower p b1) (lower p b2)
    hi = D.minD D.Up p (upper p b1) (upper p b2)

-- | Intersection of two balls.  The center precision is chosen by
-- 'D.addPrec' from both centers so that no precision is lost.
intersect :: Monad m => Ball -> Ball -> m Ball
intersect b1 b2 = intersectA (D.addPrec (center b1) (center b2)) b1 b2

-- | Addition of two balls, with the center rounded to precision @p@.
--
-- - @ center = center a + center b @ (rounded to nearest)
--
-- - @ radius = radius a + radius b @ (rounded up at 'radPrec')
--
-- The rounding error of the center is added to the radius.
add :: D.Precision -> Ball -> Ball -> Ball
add p (Ball ca ra) (Ball cb rb) =
  let (cen, inexact) = D.add_ D.Near p ca cb
  in Ball cen (D.add D.Up radPrec rb (addEpsilon inexact ra cen))

-- | Subtraction of two balls, with the center rounded to precision @p@.
--
-- - @ center = center a - center b @ (rounded to nearest)
--
-- - @ radius = radius a + radius b @ (rounded up at 'radPrec')
--
-- The rounding error of the center is added to the radius.
sub :: D.Precision -> Ball -> Ball -> Ball
sub p (Ball ca ra) (Ball cb rb) =
  let (cen, inexact) = D.sub_ D.Near p ca cb
  in Ball cen (D.add D.Up radPrec rb (addEpsilon inexact ra cen))
-- | Negation of a ball.
--
-- - center = - center b, rounded to nearest at precision @p@.
--
-- - The radius is only widened to cover the rounding error of the center.
neg :: D.Precision -> Ball -> Ball
neg p (Ball c r) =
  let (c', inexact) = D.neg_ D.Near p c
  in Ball c' (addEpsilon inexact r c')

-- | Absolute value of a ball, with the center at precision @p@.
--
-- A ball whose lower endpoint is positive is only normalized.  A ball whose
-- upper endpoint is negative is negated.  Any other ball (one containing or
-- touching zero) becomes @[0, max (-lower) upper]@.
absB :: D.Precision -> Ball -> Ball
absB p b
  | lb > 0    = normalizeBall p b
  | ub < 0    = neg p b
  | otherwise = makeA p 0 (max (negate lb) ub)
  where
    lb = lower_ b
    ub = upper_ b


-- | Multiplication of two balls, with the center rounded to precision @p@.
--
-- - If either ball contains 0, the result is built from the endpoints:
--
-- @ lower = minimum of the four endpoint products, rounded down @
--
-- @ upper = maximum of the four endpoint products, rounded up @
--
-- - Otherwise the product is computed on the absolute values ('absB') and
--   negated when the centers have opposite signs.  For two positive balls
--   @a@ and @b@:
--
-- @ center = center a * center b + radius a * radius b @
--
-- @ radius = center a * radius b + radius a * center b @
--
--   If normalizing one absolute value to precision @p@ makes it touch 0
--   (say @a@), then
--
-- @ center = center a * upper b @
--
-- @ radius = radius a * upper b @
--
--   If both touch 0, the endpoint formula above is used.
--
-- Rounding errors are added to the radius.
mul         :: D.Precision -> Ball -> Ball -> Ball
mul p b1 b2
  -- The sign-folding path is only valid when neither operand contains zero:
  -- absB maps a zero-straddling ball to [0, m], which would drop the negative
  -- half of the product (e.g. [-1,2] * [3,4] must include -4).
  | containsZero b1 || containsZero b2          = endpoints b1 b2
  | D.sgn (center b1) * D.sgn (center b2) < 0  = neg p ret
  | otherwise                                   = ret
  where
    -- A ball avoids zero iff its lower endpoint is positive or its upper
    -- endpoint is negative.
    containsZero b = not (sgnLower b > 0 || sgnUpper b < 0)
    -- Product of the absolute values; both operands are strictly signed here.
    ret = mul' (absB p b1) (absB p b2)
    mul' b b' = case (sgnLower b, sgnLower b') of
      (1, 1) -> nonzero b b'
      (1, _) -> leftzero b' b
      (_, 1) -> leftzero b b'
      _      -> endpoints b b'
    -- Both balls strictly positive.
    nonzero (Ball c r) (Ball c' r') = Ball cen rad
      where
        r''      = D.fma D.Up radPrec c r' (D.mul D.Up radPrec c' r)
        -- exact as long as both radii carry at most radPrec bits
        cr       = D.mul D.Near (2 * radPrec) r r'
        (cen, e) = D.fma_ D.Near p c c' cr
        rad      = addEpsilon e r'' cen
    -- First ball touches 0, second ball strictly positive: the extremes are
    -- (center - radius) * upper and (center + radius) * upper.
    leftzero (Ball c r) b = Ball cen rad
      where
        up       = upper p b
        (cen, e) = D.mul_ D.Near p c up
        rad      = addEpsilon e (D.mul D.Up radPrec r up) cen
    -- General case: enclose all four products of the outward-rounded endpoints.
    endpoints b b' = makeA p lo hi
      where
        l1 = lower p b
        u1 = upper p b
        l2 = lower p b'
        u2 = upper p b'
        lo = D.minD D.Down p
               (D.minD D.Down p (D.mul D.Down p l1 l2) (D.mul D.Down p l1 u2))
               (D.minD D.Down p (D.mul D.Down p u1 l2) (D.mul D.Down p u1 u2))
        hi = D.maxD D.Up p
               (D.maxD D.Up p (D.mul D.Up p l1 l2) (D.mul D.Up p l1 u2))
               (D.maxD D.Up p (D.mul D.Up p u1 l2) (D.mul D.Up p u1 u2))

-- | Division of two balls, with the center of the result at precision @p@.
-- 
-- - If a radius is \"large\", the endpoints are divided (rounding outward)
--   and a ball is made from them.
--
-- - If both radii are \"small\", division can be optimized:
--
-- - @ center = center a / center b @ (rounded to nearest)
-- 
-- - @ radius = (radius a * center b + center a * radius b) / (center b * center b) @
--   rounded up, plus @ createEpsilon 3 @ of the computed center, i.e.
--   @ 3 * 2 ^ (e - p - 1) @ where @ e @ is the exponent of the center.  The
--   inline derivation states that the first-order radius is at most
--   @ 2 * 2 ^ (e1 - e2 - p) @ too small (@ e1 @, @ e2 @ the exponents of the two
--   centers); the extra term is meant to compensate for that.
--   NOTE(review): the exact quotient half-width has @ c' * (c' - r') @ in the
--   denominator rather than @ c' * c' @; confirm that the extra term covers the
--   difference for every exponent case.
--
-- Rounding errors are added to the radius.
--
-- A negative divisor is handled by negating it, dividing, and negating the
-- quotient.  If the divisor ball contains zero, the computation fails with 'fail'.
div         :: (Monad m) => D.Precision -> Ball -> Ball -> m Ball
div p b1 b2 = if sgnLower b2 > 0 then return (div' b1 b2)
              else if sgnUpper b2 < 0 then return (neg p (div' b1 (neg p b2)))
                   else fail "Division by interval containing zero"
              where div' b b' = if smallRad b && smallRad b' then divSmall b b'
                                else divLarge b b'
                    -- A radius is small if it is 0, or if the exponent gap between
                    -- center and radius (roughly the number of correct bits) is
                    -- more than about half of the target precision p.
                    smallRad (Ball c r) = D.sgn r == 0 || 2 * (D.getExp c - D.getExp r) > fromIntegral p + 2
                    -- Only reached with a strictly positive divisor (see div above).
                    divSmall (Ball c r) (Ball c' r') = Ball cen rad
                                                       where cen = D.div D.Near p c c'
                                                             -- first-order numerator: c * r' + c' * r, rounded up
                                                             r'' = D.fma D.Up radPrec c r' (D.mul D.Up radPrec c' r)
                                                             -- c'^2 rounded down, so dividing by it can only enlarge the radius
                                                             (bsq, e) = D.sqr_ D.Down radPrec c'
                                                             -- NOTE(review): when the square was exact (e == 0) this steps one
                                                             -- value further down anyway; when inexact, the Down rounding already
                                                             -- gives a lower bound.  Either way the divisor only shrinks.
                                                             bsq' = if e == 0 then D.nextBelow bsq else bsq
                                                             r''' = D.div D.Up radPrec r'' bsq'
                                                             -- now r''' is at most 2 * 2 ^ (e1 - e2 - p) too small
                                                             rad = D.add D.Up radPrec r''' (createEpsilon 3 cen)
                    -- The divisor b' is strictly positive here.  For each dividend endpoint,
                    -- pick the divisor endpoint that gives the extreme quotient, depending
                    -- on the sign of that dividend endpoint.
                    divLarge b b' = makeA p l u
                                    where l = D.div D.Down p l1 (if D.sgn l1 < 0 then l2 else u2)
                                          u = D.div D.Up p u1 (if D.sgn u1 < 0 then u2 else l2)
                                          l1 = lower p b
                                          l2 = lower p b'
                                          u1 = upper p b
                                          u2 = upper p b'

-- | Square root of a ball, with the center at precision @p@.  Fails (via
-- 'fail') if the ball contains negative numbers.
--
-- - If the center is non-zero and the radius is small relative to it, the
--   center is @sqrt (center b)@ rounded to nearest and the radius is
--   @radius b / center@ rounded up, plus the rounding error of the center.
--   (For @c >= r@, @sqrt c - sqrt (c - r) = r / (sqrt c + sqrt (c - r)) <= r / sqrt c@.)
--
-- - Otherwise the square roots of both endpoints are taken with outward
--   rounding and a ball is made from them.
sqrt :: Monad m => D.Precision -> Ball -> m Ball
sqrt p b@(Ball c r)
  | lower_ b < 0 = fail "Sqrt of a interval containing negative numbers"
  | radSmall     = return viaCenter
  | otherwise    = return viaEndpoints
  where
    radSmall = D.sgn c /= 0
            && (D.sgn r == 0 || D.getExp c `quot` 2 - D.getExp r > fromIntegral p)
    viaCenter =
      let (cen, e) = D.sqrt_ D.Near p c
      in Ball cen (addEpsilon e (D.div D.Up radPrec r cen) cen)
    viaEndpoints =
      makeA p (D.sqrt D.Down p (D.sub D.Down p c r))
              (D.sqrt D.Up p (D.add D.Up p c r))

-- | @ e ^ b @: exponential of a ball, with the center at precision @p@.
--
-- @exp@ is increasing, so the result encloses
-- @[exp (center - radius), exp (center + radius)]@, with each endpoint
-- rounded outward.
exp              :: D.Precision -> Ball -> Ball
exp p (Ball c r) = makeA p l u
  where
    -- Lower endpoint: exp of (center - radius), both steps rounded down.
    l = D.exp D.Down p (D.sub D.Down p c r)
    -- Upper endpoint: exp of (center + radius), both steps rounded up.
    u = D.exp D.Up p (D.add D.Up p c r)

-- | Natural logarithm of a ball, with the center at precision @p@.  Fails
-- (via 'fail') if the ball contains a non-positive number.
--
-- @log@ is increasing, so the result encloses
-- @[log (center - radius), log (center + radius)]@, with each endpoint
-- rounded outward.
log                :: Monad m => D.Precision -> Ball -> m Ball
log p b@(Ball c r)
  | lower_ b <= 0 = fail "Domain of log is R+"
  | otherwise     = return (makeA p l u)
  where
    -- The guard ensures center - radius > 0, i.e. inside log's domain.
    l = D.log D.Down p (D.sub D.Down p c r)
    u = D.log D.Up p (D.add D.Up p c r)

-- | Partial comparison of two balls.  Endpoints are taken at the precision
-- of each ball's center:
--
-- - 'Less' if @upper a < lower b@,
--
-- - 'Greater' if @lower a > upper b@,
--
-- - 'Incomparable' otherwise (the balls overlap or touch).
compareB :: Ball -> Ball -> POrdering
compareB a b
  | upper_ a < lower_ b = Less
  | lower_ a > upper_ b = Greater
  | otherwise           = Incomparable

-- | Maximum of two balls, with the center at precision @p@:
--
-- - lower = max (lower a) (lower b), rounded down
--
-- - upper = max (upper a) (upper b), rounded up
maxB :: D.Precision -> Ball -> Ball -> Ball
maxB p a b =
  let lo = D.maxD D.Down p (lower p a) (lower p b)
      hi = D.maxD D.Up p (upper p a) (upper p b)
  in makeA p lo hi

-- | Minimum of two balls, with the center at precision @p@:
--
-- - lower = min (lower a) (lower b), rounded down
--
-- - upper = min (upper a) (upper b), rounded up
minB :: D.Precision -> Ball -> Ball -> Ball
minB p a b =
  let lo = D.minD D.Down p (lower p a) (lower p b)
      hi = D.minD D.Up p (upper p a) (upper p b)
  in makeA p lo hi

-- | Make a ball from a string, with the center converted by 'D.fromString'
-- at precision @p@.  Unlike 'fromDyadic', the radius is always
-- @createEpsilon 1@ of the center, whether or not the conversion was exact.
fromString :: D.Precision -> String -> Ball
fromString p s =
  let cen = D.fromString s p 10 -- NOTE(review): 10 is presumably the base
  in Ball cen (createEpsilon 1 cen)

-- | Number of whole decimal digits carried by the given number of bits:
-- @floor (bits * log10 2)@.
decimalPrec :: Word -> Word
decimalPrec bits = floor (logBase 10 2 * fromIntegral bits :: Double)

-- | Number of bits of the center that are known to be correct.  For a ball
-- with zero radius this is the full precision of the center; otherwise it is
-- the exponent gap between center and radius, clamped at 0.  A negative
-- radius is reported with 'error'.
correctDigits :: Ball -> Word
correctDigits (Ball c r) = case D.zero `compare` r of
  EQ -> fromIntegral (D.getPrec c)
  LT -> fromIntegral (max 0 (D.getExp c - D.getExp r))
  GT -> error "Ball.correctDigits: radius should be a nonnegative, non-degenerate number"