{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-|
    Module      :  Data.Number.ER.Real.Approx.Interval
    Description :  safe interval arithmetic
    Copyright   :  (c) Michal Konecny
    License     :  BSD3

    Maintainer  :  mik@konecny.aow.cz
    Stability   :  experimental
    Portability :  portable

    This module defines an arbitrary precision interval type and
    most of its interval arithmetic operations.
-}
module Data.Number.ER.Real.Approx.Interval 
(
    ERInterval(..),
    normaliseERInterval
)
where

import qualified Data.Number.ER.Real.Approx as RA
import qualified Data.Number.ER.Real.Approx.Elementary as RAEL
import qualified Data.Number.ER.Real.Base as B
import qualified Data.Number.ER.ExtendedInteger as EI

import Data.Number.ER.BasicTypes
import Data.Number.ER.Misc

import Data.Ratio

import qualified Text.Html as H

import Data.Typeable
import Data.Generics.Basics
import Data.Binary
--import BinaryDerive

{-|
    Type for arbitrary precision interval arithmetic, parametrised by
    the type of the interval endpoints.
-}
data ERInterval base
    = ERIntervalEmpty
        -- ^ usually represents a computation error
        --   (top element in the interval domain)
    | ERIntervalAny
        -- ^ represents no knowledge of the result
        --   (bottom element in the interval domain)
    | ERInterval
        { erintv_left :: base   -- ^ left (lower) endpoint
        , erintv_right :: base  -- ^ right (upper) endpoint
        }
    deriving (Typeable, Data)
    
{-
    Binary serialisation (originally generated by BinaryDerive).
    A value is written as a one-byte constructor tag
    (0 = ERIntervalEmpty, 1 = ERIntervalAny, 2 = ERInterval), followed,
    for a proper interval, by its left and then its right endpoint.
    Decoding any other tag fails with "no parse".
-}
instance (Binary a) => Binary (ERInterval a) where
  put interval =
    case interval of
      ERIntervalEmpty -> putWord8 0
      ERIntervalAny -> putWord8 1
      ERInterval l r -> putWord8 2 >> put l >> put r
  get =
    getWord8 >>= decodeTag
    where
    decodeTag tag =
      case tag of
        0 -> return ERIntervalEmpty
        1 -> return ERIntervalAny
        2 -> do
          l <- get
          r <- get
          return (ERInterval l r)
        _ -> fail "no parse"
    
    
{-|
    Convert to a normal form, ie:

    * no NaNs as endpoints

    * @l <= r@

    * no (-Infty, +Infty); the whole real line is 'ERIntervalAny'

    An interval with both endpoints NaN becomes 'ERIntervalAny'.
    A single NaN endpoint is replaced by the corresponding infinity and
    the result is normalised again, so that eg @[NaN, +Infty]@ becomes
    'ERIntervalAny' rather than the non-normal @[-Infty, +Infty]@.
    An interval with @l > r@ becomes 'ERIntervalEmpty'.
    'ERIntervalEmpty' and 'ERIntervalAny' are returned unchanged.
-}
normaliseERInterval ::
    (B.ERRealBase b) =>
    ERInterval b -> ERInterval b
normaliseERInterval (ERInterval l r)
    | B.isPlusInfinity r && B.isPlusInfinity (- l) =
        ERIntervalAny
    | B.isERNaN l && B.isERNaN r =
        ERIntervalAny
    | B.isERNaN l =
        -- re-normalise: r may be +Infty, which would give the whole line
        normaliseERInterval (ERInterval (- B.plusInfinity) r)
    | B.isERNaN r =
        -- re-normalise: l may be -Infty, which would give the whole line
        normaliseERInterval (ERInterval l B.plusInfinity)
    | l > r = ERIntervalEmpty
    -- when no guard above holds, matching falls through to the next equation
normaliseERInterval i = i

{-|
    erintvPrecision returns an approximation of the number of bits of
    precision carried by an interval, computed from its width as

    >  - 1 - log_2 (r - l)

    where the logarithm is approximated by 'B.getApproxBinaryLog'.
    (A normalised variant, @- log_2 ((r - l) / (1 + abs(r) + abs(l)))@,
    is currently not used.)

    The empty interval has precision +Infty and 'ERIntervalAny' has
    precision -Infty.  A singleton interval has width 0; its precision
    is +Infty provided 'B.getApproxBinaryLog' maps 0 to -Infty.
-}
erintvPrecision ::
    (B.ERRealBase b) =>
    ERInterval b -> EI.ExtendedInteger
erintvPrecision ERIntervalEmpty = EI.PlusInfinity
erintvPrecision ERIntervalAny = EI.MinusInfinity
erintvPrecision (ERInterval l r) =
    -1 - B.getApproxBinaryLog width
    where
    width = r - l

{-|
    The granularity of an interval: the smaller of the granularities of
    its two endpoints, or 0 for 'ERIntervalAny' and 'ERIntervalEmpty'.
-}
erintvGranularity ::
    (B.ERRealBase b) =>
    ERInterval b -> Int
erintvGranularity (ERInterval l r) =
    B.getGranularity l `min` B.getGranularity r
erintvGranularity _ = 0

{- syntactic comparisons -}

{-|
    A syntactic equality test: both intervals must use the same
    constructor, and proper intervals must have endpoints that are equal
    under the endpoint type's @==@.
-}
erintvEqualApprox ::
    (B.ERRealBase b) =>
    ERInterval b -> ERInterval b -> Bool
erintvEqualApprox i1 i2 =
    case (i1, i2) of
        (ERInterval l1 r1, ERInterval l2 r2) -> l1 == l2 && r1 == r2
        (ERIntervalEmpty, ERIntervalEmpty) -> True
        (ERIntervalAny, ERIntervalAny) -> True
        _ -> False

{-|
    A syntactic linear order:
    @ERIntervalEmpty < ERIntervalAny <@ any proper interval.
    Proper intervals are ordered lexicographically, first by their left
    endpoint and then by their right endpoint.
-}
erintvCompareApprox ::
    (B.ERRealBase b) =>
    ERInterval b -> ERInterval b -> Ordering
erintvCompareApprox (ERInterval l1 r1) (ERInterval l2 r2) =
    compare (l1, r1) (l2, r2)
erintvCompareApprox i1 i2 =
    compare (rank i1) (rank i2)
    where
    -- position of each constructor in the syntactic order
    rank :: ERInterval c -> Int
    rank ERIntervalEmpty = 0
    rank ERIntervalAny = 1
    rank _ = 2

{- semantic comparisons -}

{-|
    Compare for equality two intervals interpreted as approximations of
    two single real numbers.  The result is:

    * @Just True@ when both intervals are the same singleton;

    * @Just False@ when the intervals are disjoint;

    * 'Nothing' (ie bottom) otherwise, including whenever either
      interval is 'ERIntervalEmpty' or 'ERIntervalAny'.
-}
erintvEqualReals ::
    (B.ERRealBase b) =>
    ERInterval b ->
    ERInterval b ->
    Maybe Bool
erintvEqualReals (ERInterval l1 r1) (ERInterval l2 r2)
    | bothExact && l1 == l2 = Just True
    | disjoint = Just False
    | otherwise = Nothing
    where
    bothExact = l1 == r1 && l2 == r2
    disjoint = r1 < l2 || l1 > r2
erintvEqualReals _ _ = Nothing

{-|
    Compare in natural order two intervals interpreted as approximations
    of two single real numbers.  The result is:

    * @Just LT@ when the first interval lies entirely below the second;

    * @Just GT@ when it lies entirely above the second;

    * @Just EQ@ when both are the same singleton;

    * 'Nothing' (ie bottom) otherwise, including whenever either
      interval is 'ERIntervalEmpty' or 'ERIntervalAny'.
-}
erintvCompareReals ::
    (B.ERRealBase b) =>
    ERInterval b ->
    ERInterval b ->
    Maybe Ordering
erintvCompareReals (ERInterval l1 r1) (ERInterval l2 r2)
    | r1 < l2 = Just LT
    | l1 > r2 = Just GT
    | l1 == r1 && l2 == r2 && l1 == l2 = Just EQ
    -- undecided proper intervals fall through to the equation below
erintvCompareReals _ _ = Nothing

{-|
    Decide @x <= y@ for two intervals interpreted as approximations of
    single real numbers @x@ and @y@.  The result is:

    * @Just True@ when the right endpoint of the first interval is at most
      the left endpoint of the second;

    * @Just False@ when the left endpoint of the first exceeds the right
      endpoint of the second;

    * 'Nothing' (ie bottom) otherwise, including whenever either
      interval is 'ERIntervalEmpty' or 'ERIntervalAny'.
-}
erintvLeqReals ::
    (B.ERRealBase b) =>
    ERInterval b ->
    ERInterval b ->
    Maybe Bool
erintvLeqReals (ERInterval l1 r1) (ERInterval l2 r2)
    | r1 <= l2 = Just True
    | l1 > r2 = Just False
    -- undecided proper intervals fall through to the equation below
erintvLeqReals _ _ = Nothing


{-|
    The default point at which to split an interval, returned as a
    singleton interval:

    > [-Infty,+Infty] |-> 0

    > [-Infty,x] |-> 2*x-1   (x <= 0)

    > [-Infty,x] |-> 0       (x > 0)

    > [x,+Infty] |-> 2*x+1   (x >= 0)

    > [x,+Infty] |-> 0       (x < 0)

    > [x,y] |-> (x+y)/2

    > empty |-> empty

    Before this arithmetic, the endpoint called @x@ in each rule has its
    granularity raised to at least the given granularity
    (via 'B.setMinGranularity').
-}
erintvDefaultBisectPt ::
    (B.ERRealBase b) => 
    Granularity -> 
    (ERInterval b) ->
    (ERInterval b)
erintvDefaultBisectPt _ ERIntervalAny = 0
erintvDefaultBisectPt _ ERIntervalEmpty = ERIntervalEmpty
erintvDefaultBisectPt gran (ERInterval l r)
    | B.isPlusInfinity r = point $ if l < 0 then 0 else 2 * lFine + 1
    | B.isPlusInfinity (-l) = point $ if r > 0 then 0 else 2 * rFine - 1
    | otherwise = point $ (lFine + r) / 2
    where
    lFine = B.setMinGranularity gran l
    rFine = B.setMinGranularity gran r
    point m = ERInterval m m
    

{-|
    Split an interval into two halves that meet at the given point or,
    when no point is given, at 'erintvDefaultBisectPt' for the given
    granularity.  The halves are the hulls (via 'RA.\/') of the lower
    bound with the point, and of the point with the upper bound.
-}
erintvBisect ::
    (B.ERRealBase b, RealFrac b) => 
    Granularity -> 
    (Maybe (ERInterval b)) ->
    (ERInterval b) ->
    (ERInterval b, ERInterval b)
erintvBisect gran maybePt i =
    (lower RA.\/ pt, pt RA.\/ upper)
    where
    (lower, upper) = RA.bounds i
    pt = maybe (erintvDefaultBisectPt gran i) id maybePt

{-|
    Equality of intervals viewed as approximations of single real numbers.
    Raises an error whenever 'erintvEqualReals' cannot decide, eg for
    overlapping intervals or when either side is 'ERIntervalEmpty' or
    'ERIntervalAny'.
-}
instance (B.ERRealBase b) => Eq (ERInterval b) where
    i1 == i2 =
        maybe undecided id (erintvEqualReals i1 i2)
        where
        undecided =
            error $
                "ERInterval: Eq: comparing overlapping intervals:\n" ++
                show i1 ++ "\n" ++
                show i2

{-|
    Ordering of intervals viewed as approximations of single real
    numbers.  'compare' raises an error whenever 'erintvCompareReals'
    cannot decide (eg for overlapping intervals).  'max' and 'min' are
    defined endpoint-wise and their results are normalised.
-}
instance (B.ERRealBase b) => Ord (ERInterval b) where
    compare i1 i2 = 
        case erintvCompareReals i1 i2 of
            Nothing -> 
                error $ 
                    "ERInterval: Ord: comparing overlapping intervals:\n" ++
                    show i1 ++ "\n" ++
                    show i2
            Just r -> r
    {- max:
       (Default implementation is wrong in this case:
        eg compare is not defined for overlapping intervals.)
    -}
    max (ERInterval l1 r1) (ERInterval l2 r2) =
        normaliseERInterval $ ERInterval (max l1 l2) (max r1 r2)
    max ERIntervalEmpty _ = ERIntervalEmpty
    max _ ERIntervalEmpty = ERIntervalEmpty
    max ERIntervalAny ERIntervalAny = ERIntervalAny
    -- The maximum is at least l but unbounded above.  Normalise because
    -- l may be -Infty, giving the whole line, which must be ERIntervalAny.
    max ERIntervalAny (ERInterval l _) =
        normaliseERInterval $ ERInterval l B.plusInfinity
    max (ERInterval l _) ERIntervalAny =
        normaliseERInterval $ ERInterval l B.plusInfinity
    {- min: symmetric to max -}
    min (ERInterval l1 r1) (ERInterval l2 r2) =
        normaliseERInterval $ ERInterval (min l1 l2) (min r1 r2)
    min ERIntervalEmpty _ = ERIntervalEmpty
    min _ ERIntervalEmpty = ERIntervalEmpty
    min ERIntervalAny ERIntervalAny = ERIntervalAny
    -- The minimum is at most r but unbounded below.  Normalise because
    -- r may be +Infty, giving the whole line, which must be ERIntervalAny.
    min ERIntervalAny (ERInterval _ r) =
        normaliseERInterval $ ERInterval (- B.plusInfinity) r
    min (ERInterval _ r) ERIntervalAny =
        normaliseERInterval $ ERInterval (- B.plusInfinity) r
        
{-|
    Shows an interval via 'erintvShow' with 16 digits, with granularity
    shown and components not shown.
-}
instance (B.ERRealBase b) => Show (ERInterval b) where
    show i = erintvShow 16 True False i
    
{-|
    Format an interval: @[NONE]@ for the empty interval, @[ANY]@ for
    'ERIntervalAny', @\<x\>@ for a singleton and @[l,r]@ otherwise.
    Endpoints are formatted by 'B.showDiGrCmp' with the given number of
    digits and the granularity/components display flags.
-}
erintvShow numDigits showGran showComponents interval =
    case interval of
        ERIntervalEmpty -> "[NONE]"
        ERIntervalAny -> "[ANY]"
        ERInterval l r
            | l == r -> "<" ++ fmt l ++ ">"
            | otherwise -> "[" ++ fmt l ++ "," ++ fmt r ++ "]"
    where
    fmt = B.showDiGrCmp numDigits showGran showComponents
        
{-|
    HTML rendering: a singleton is rendered as the 'show' of its endpoint,
    a proper interval as a two-row table holding the left and right
    endpoints, and the special intervals as their 'show' text.
-}
instance (B.ERRealBase b, H.HTML b) => H.HTML (ERInterval b)
    where
    toHtml i =
        case i of
            ERInterval l r
                | l == r -> H.toHtml (show l)
                | otherwise ->
                    H.simpleTable [] [] [[H.toHtml l], [H.toHtml r]]
            _ -> H.toHtml (show i)

{-|
    Interval arithmetic (Num).

    Left endpoints are computed by negating the operands and the result,
    eg @-((-l1) + (-l2))@.  As the comment on addition says, this reverses
    the rounding mode, so the left endpoint is rounded the opposite way
    to the right one.

    In addition and multiplication the 'ERIntervalAny' equations come
    before the 'ERIntervalEmpty' ones, so 'ERIntervalAny' takes
    precedence when the operands are one of each.
-}
instance (B.ERRealBase b) => Num (ERInterval b) where
    {- fromInteger: a singleton whose two endpoints are converted by the
       endpoint type's fromInteger, then normalised -}
    fromInteger n =
        normaliseERInterval $ ERInterval (fromInteger n) (fromInteger n)
    {- abs -}
    abs (ERInterval l r)
        | l < 0 && r > 0 = ERInterval 0 (max (-l) r) -- 0 lies strictly inside
        | r <= 0 = ERInterval (-r) (-l) -- entirely non-positive: mirror
        | otherwise = ERInterval l r -- entirely non-negative: unchanged
    abs ERIntervalAny = ERInterval 0 B.plusInfinity
    abs ERIntervalEmpty = ERIntervalEmpty
    {- signum: an interval covering the possible signs.
       NOTE(review): if an endpoint is NaN (which normalised intervals
       do not have), presumably every comparison below is False and no
       guard matches - confirm. -}
    signum i@(ERInterval l r)
        | l < 0 && r > 0 = ERInterval (-1) 1 -- need many-valuedness via sequences of intervals
        | r < 0 = ERInterval (-1) (-1)
        | l > 0 = ERInterval 1 1
        | l == 0 && r == 0 = i
        | l == 0 = ERInterval 0 1
        | r == 0 = ERInterval (-1) 0
    signum ERIntervalAny = ERInterval (-1) 1
    signum ERIntervalEmpty = ERIntervalEmpty
    {- negate: swap the endpoints and negate them -}
    negate (ERInterval l r) = (ERInterval (-r) (-l))
    negate ERIntervalEmpty = ERIntervalEmpty
    negate ERIntervalAny = ERIntervalAny
    {- addition -}
    (ERInterval l1 r1) + (ERInterval l2 r2) =
        normaliseERInterval $
        ERInterval 
            (-((-l1) + (-l2))) -- reverse the rounding mode
            (r1 + r2)
    ERIntervalAny + i2 = ERIntervalAny
    i1 + ERIntervalAny = ERIntervalAny
    ERIntervalEmpty + i2 = ERIntervalEmpty
    i1 + ERIntervalEmpty = ERIntervalEmpty
    {- multiplication: the left endpoint is the minimum of the four
       endpoint products computed with reversed rounding (prodsL), the
       right endpoint the maximum of the ordinary products (prodsR).
       If any product is NaN the result is ERIntervalAny. -}
    (ERInterval l1 r1) * (ERInterval l2 r2)
        | haveNan = ERIntervalAny
        | otherwise =
            normaliseERInterval $
            ERInterval minProd maxProd
        where
        haveNan = or $ map B.isERNaN (prodsL ++ prodsR)
        minProd = foldl1 min prodsL
        maxProd = foldl1 max prodsR
        prodsL = [-((-l1) * l2), -((-l1) * r2), -((-r1) * l2), -((-r1) * r2)]
        prodsR = [l1 * l2, l1 * r2, r1 * l2, r1 * r2]
    ERIntervalAny * i2 = ERIntervalAny
    i1 * ERIntervalAny = ERIntervalAny
    ERIntervalEmpty * i2 = ERIntervalEmpty
    i1 * ERIntervalEmpty = ERIntervalEmpty

{-|
    Interval division (Fractional).  As in multiplication, left endpoints
    are computed with negated operands and result (eg @-(l1 / (-l2))@)
    to reverse the rounding mode for the lower bound.
-}
instance (B.ERRealBase b) => Fractional (ERInterval b) where
    {- fromRational: interval division of the numerator by the
       denominator, each converted with fromInteger -}
    fromRational rat =
        (fromInteger $ numerator rat)
        / (fromInteger $ denominator rat)
    {- division:
       - a divisor with 0 strictly inside gives ERIntervalAny;
       - any NaN quotient gives ERIntervalAny;
       - a zero divisor endpoint whose sign disagrees with the rest of the
         divisor (detected via the sign of 1/0) is replaced by the zero of
         the other sign and the division is retried;
       - otherwise the result runs from the minimum of the
         reversed-rounding quotients (divsL) to the maximum of the
         ordinary quotients (divsR). -}
    (ERInterval l1 r1) / (ERInterval l2 r2)
        | l2 < 0 && r2 > 0 = ERIntervalAny
        | haveNan = 
--            unsafePrint "ERInterval: /: haveNan" $ 
            ERIntervalAny
        | l2 == 0 && r2 > 0 && 1/l2 < 0 = -- minus 0
            (ERInterval l1 r1) / (ERInterval (-l2) r2) -- correct it to +0
        | r2 == 0 && l2 < 0 && 1/r2 > 0 = -- plus 0
            (ERInterval l1 r1) / (ERInterval l2 (-r2)) -- correct it to -0
        | otherwise =
            normaliseERInterval $
            ERInterval minDiv maxDiv
        where
        haveNan = or $ map B.isERNaN (divsL ++ divsR)
        minDiv = foldl1 min divsL
        maxDiv = foldl1 max divsR
        divsL = [-(l1 / (-l2)), -(l1 / (-r2)), -(r1 / (-l2)), -(r1 / (-r2))]
        divsR = [l1 / l2, l1 / r2, r1 / l2, r1 / r2]
    ERIntervalAny / i2 = ERIntervalAny
    i1 / ERIntervalAny = ERIntervalAny
    ERIntervalEmpty / i2 = ERIntervalEmpty
    i1 / ERIntervalEmpty = ERIntervalEmpty
            
{-|
    General approximation operations for intervals: precision and
    granularity, the special approximations, intersection, refinement,
    comparisons, conversion from 'Double' and formatting.
-}
instance (B.ERRealBase b, RealFrac b) => RA.ERApprox (ERInterval b) where
    initialiseBaseArithmetic _ =
        B.initialiseBaseArithmetic (0 :: b)
    getPrecision i = erintvPrecision i
    getGranularity i = erintvGranularity i
    {- setMinGranularity: the left endpoint is processed negated, as in
       addition, presumably so that any rounding moves it downwards -}
    setMinGranularity gr (ERInterval l r) =
        normaliseERInterval $
        (ERInterval (- (B.setMinGranularity gr (-l))) (B.setMinGranularity gr r))
    setMinGranularity _ i = i
    {- setGranularity: same negation trick for the left endpoint -}
    setGranularity gr (ERInterval l r) =
        normaliseERInterval $
        (ERInterval (- (B.setGranularity gr (-l))) (B.setGranularity gr r))
    setGranularity _ i = i
    {- bottomApprox -}
    bottomApprox = ERIntervalAny
    {- emptyApprox -}
    emptyApprox = ERIntervalEmpty
    {- isEmpty -}
    isEmpty ERIntervalEmpty = True
    isEmpty _ = False
    {- isBottom: also True for a (non-normalised) [-Infty, +Infty] -}
    isBottom ERIntervalAny = True
    isBottom (ERInterval l r) =
        B.isPlusInfinity r && B.isPlusInfinity (-l)
    isBottom _ = False
    {- isExact -}
    isExact ERIntervalEmpty = False
    isExact ERIntervalAny = False
    isExact (ERInterval l r) = l == r
    {- isBounded -}
    isBounded ERIntervalEmpty = True
    isBounded ERIntervalAny = False
    isBounded (ERInterval l r) =
        (- B.plusInfinity) < l && r < B.plusInfinity
    {- intersection -}
    ERIntervalEmpty /\ _ = ERIntervalEmpty
    _ /\ ERIntervalEmpty = ERIntervalEmpty
    ERIntervalAny /\ i = i
    i /\ ERIntervalAny = i
    (ERInterval l1 r1) /\ (ERInterval l2 r2) =
        normaliseERInterval $
        ERInterval (max l1 l2) (min r1 r2)
    {- intersectMeasureImprovement: the intersection together with the
       ratio of the width of the first interval to the width of the
       intersection (1 when both widths may be 0) -}
    intersectMeasureImprovement _ ERIntervalEmpty _ = (ERIntervalEmpty, 1)
    intersectMeasureImprovement _ _ ERIntervalEmpty = (ERIntervalEmpty, 1)
    intersectMeasureImprovement _ ERIntervalAny i = (i, 1)
    intersectMeasureImprovement _ i ERIntervalAny = (i, 1)
    intersectMeasureImprovement ix i1 i2 =
        (isec, impr)
        where
        isec = i1 RA./\ i2
        impr
            | 0 `RA.refines` isecWidth && 0 `RA.refines` i1Width = 1 -- 0 -> 0 is no improvement
            | otherwise = i1Width / isecWidth
        i1Width = i1H - i1L
        isecWidth = isecH - isecL
        (isecL, isecH) = RA.bounds $ RA.setMinGranularity gran isec
        (i1L, i1H) = RA.bounds $ RA.setMinGranularity gran i1
        gran = effIx2gran ix
    {- refines -}
    refines _ ERIntervalAny = True
    refines ERIntervalEmpty _ = True
    refines ERIntervalAny _ = False
    refines _ ERIntervalEmpty = False
    refines (ERInterval l1 r1) (ERInterval l2 r2) =
        l2 <= l1 && r1 <= r2
    {- semantic comparisons -}
    equalReals = erintvEqualReals
    compareReals = erintvCompareReals
    leqReals = erintvLeqReals
    {- non-semantic comparisons -}
    equalApprox = erintvEqualApprox
    compareApprox = erintvCompareApprox
    {- conversion from Double: normalised like fromInteger, so that if
       B.fromDouble yields a NaN the result is ERIntervalAny rather than
       an interval with NaN endpoints; for other values the singleton is
       unchanged by normalisation -}
    double2ra d =
        normaliseERInterval $ ERInterval b b
        where
        b = B.fromDouble d
    {- formatting -}
    showApprox = erintvShow

{-|
    Interval-specific operations: bounds as machine numbers and extended
    integers, bisection, hull ('\/') and endpoint bounds.
-}
instance (B.ERRealBase b, RealFrac b) => RA.ERIntApprox (ERInterval b)
    where
    {- bounds as Double; ERIntervalAny maps to (-1/0, 1/0) -}
    doubleBounds interval =
        case interval of
            ERIntervalAny -> (negate inf, inf)
            ERIntervalEmpty ->
                error "ERInterval: doubleBounds: empty interval"
            ERInterval l r -> (B.toDouble l, B.toDouble r)
        where
        inf = 1 / 0
    {- bounds as Float; ERIntervalAny maps to (-1/0, 1/0) -}
    floatBounds interval =
        case interval of
            ERIntervalAny -> (negate inf, inf)
            ERIntervalEmpty ->
                error "ERInterval: floatBounds: empty interval"
            ERInterval l r -> (B.toFloat l, B.toFloat r)
        where
        inf = 1 / 0
    {- bounds as extended integers: the lower bound is the negated ceiling
       of the negated left endpoint (ie its floor), the upper bound the
       ceiling of the right endpoint; infinite endpoints map to the
       corresponding extended-integer infinities -}
    integerBounds interval =
        case interval of
            ERIntervalAny -> (negate EI.PlusInfinity, EI.PlusInfinity)
            ERIntervalEmpty ->
                error "ERInterval: integerBounds: empty interval"
            ERInterval l r -> (negate (roundUp (negate l)), roundUp r)
        where
        roundUp x
            | B.isPlusInfinity x = EI.PlusInfinity
            | B.isPlusInfinity (negate x) = EI.MinusInfinity
            | otherwise = ceiling x
    {- bisection uses a granularity one finer than the domain's -}
    defaultBisectPt dom =
        erintvDefaultBisectPt finerGran dom
        where
        finerGran = RA.getGranularity dom + 1
    bisectDomain maybePt dom =
        erintvBisect finerGran maybePt dom
        where
        finerGran = RA.getGranularity dom + 1
    {- \/ : hull of two intervals; the empty interval is neutral and
       ERIntervalAny absorbs everything else -}
    i1 \/ i2 =
        case (i1, i2) of
            (ERIntervalEmpty, _) -> i2
            (_, ERIntervalEmpty) -> i1
            (ERIntervalAny, _) -> ERIntervalAny
            (_, ERIntervalAny) -> ERIntervalAny
            (ERInterval l1 r1, ERInterval l2 r2) ->
                normaliseERInterval $
                ERInterval (min l1 l2) (max r1 r2)
    {- RA.bounds: the endpoints as singleton intervals -}
    bounds interval =
        case interval of
            ERIntervalAny ->
                (exact (negate B.plusInfinity), exact B.plusInfinity)
            ERIntervalEmpty -> (ERIntervalEmpty, ERIntervalEmpty)
            ERInterval l r -> (exact l, exact r)
        where
        exact x = ERInterval x x

{-|
    Elementary functions on intervals.  No method is overridden here:
    every operation uses the class's default implementation.
-}
instance (RealFrac b, B.ERRealBase b) => RAEL.ERApproxElementary (ERInterval b)