{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-|
    Module      :  Data.Number.ER.RnToRm.UnitDom.ChebyshevBase.Polynom.Ring
    Description :  (internal) uniformly rounded pointwise ring operations
    Copyright   :  (c) 2007-2008 Michal Konecny
    License     :  BSD3

    Maintainer  :  mik@konecny.aow.cz
    Stability   :  experimental
    Portability :  portable
    
    Internal module for "Data.Number.ER.RnToRm.UnitDom.ChebyshevBase.Polynom".
    
    Implementation of addition and multiplication over polynomials 
    with pointwise rounding uniform over the whole unit domain.
-}
module Data.Number.ER.RnToRm.UnitDom.ChebyshevBase.Polynom.Ring

where

import Data.Number.ER.RnToRm.UnitDom.ChebyshevBase.Polynom.Basic

import qualified Data.Number.ER.Real.Base as B
import qualified Data.Number.ER.Real.DomainBox as DBox
import Data.Number.ER.Real.DomainBox (VariableID(..), DomainBox, DomainIntBox)
import Data.Number.ER.Misc

import qualified Data.Map as Map

{-|
    Negate a polynomial exactly.

    Every coefficient is negated in place; negation introduces no rounding,
    so a single polynomial is returned rather than a lower/upper pair.
-}
chplNeg (ERChebPoly coeffs) =
    ERChebPoly (fmap negate coeffs)

{-|
    Add a constant to a polynomial, rounding downwards and upwards.

    Only the constant term is touched: it is replaced by the old constant
    (zero when absent) plus @c@, once rounded down and once rounded up.
    All other coefficients are shared by both results. The returned width
    is the gap between the two new constant terms.
-}
chplAddConst ::
    (B.ERRealBase b, DomainBox box varid Int, Ord box) =>
    b ->
    ERChebPoly box b ->
    (ERChebPoly box b, ERChebPoly box b, b)
        {-^ lower and upper bounds on the sum and an upper bound on their difference -}
chplAddConst c (ERChebPoly coeffs) =
    (withConst constDown, withConst constUp, constUp - constDown)
    where
    -- a polynomial without a constant term has constant zero
    oldConst = Map.findWithDefault 0 chplConstTermKey coeffs
    constUp = oldConst `plusUp` c
    constDown = oldConst `plusDown` c
    withConst newConst =
        ERChebPoly (Map.insert chplConstTermKey newConst coeffs)

-- | Upper bound of adding constant @c@ to @p@ (see 'chplAddConst').
chplAddConstUp c p = upper
    where
    (_, upper, _) = chplAddConst c p
-- | Lower bound of adding constant @c@ to @p@ (see 'chplAddConst').
chplAddConstDown c p = lower
    where
    (lower, _, _) = chplAddConst c p

{-|
    Add two polynomials, rounding downwards and upwards.

    The coefficient-wise sum is computed twice, once rounded upwards and
    once rounded downwards. The per-coefficient gaps between the two are
    summed (rounding up) into @maxError@. That sum bounds how far the
    upward-rounded polynomial can be from the exact sum at any point of
    the unit domain, because every non-constant Chebyshev term takes values
    in [-1,1] there. Shifting the constant term of the upward-rounded
    polynomial by @+maxError@ / @-maxError@ then yields an upper / lower
    bound.

    Both bounds share the same non-constant coefficients and differ only in
    the constant term, so their pointwise difference agrees with the
    returned width @2 * maxError@ (up to the rounding of that constant).
    Building the upper bound from downward-rounded coefficients and the
    lower bound from upward-rounded ones would let the two drift apart by
    up to @3 * maxError@ where a non-constant term reaches -1 or 1.
-}
chplAdd ::
    (B.ERRealBase b, DomainBox box varid Int, Ord box) =>
    ERChebPoly box b ->
    ERChebPoly box b ->
    (ERChebPoly box b, ERChebPoly box b, b)
        {-^ lower and upper bounds on the sum and an upper bound on their difference -}
chplAdd (ERChebPoly coeffs1) (ERChebPoly coeffs2) =
    (ERChebPoly sumCoeffsDown, ERChebPoly sumCoeffsUp, 2 * maxError)
    where
    sumCoeffsUp =
        Map.insertWith plusUp chplConstTermKey maxError coeffsUp
        -- shift the shared coefficients up by the total rounding error
    sumCoeffsDown =
        Map.insertWith plusDown chplConstTermKey (- maxError) coeffsUp
        -- shift the shared coefficients down by the total rounding error
    coeffsUp =
        Map.unionWith plusUp coeffs1 coeffs2
        -- point-wise sum of polynomials with coeffs rounded upwards
    coeffsDown =
        Map.unionWith plusDown coeffs1 coeffs2
        -- point-wise sum of polynomials with coeffs rounded downwards
    maxError =
        foldr plusUp 0 $ Map.elems $
            Map.intersectionWith (-) coeffsUp coeffsDown
        -- non-constant terms are multiplied by quantities in [-1,1],
        -- so each coefficient's rounding error may push the value either
        -- way; the sum of all gaps (rounded up) bounds the total deviation
        -- of coeffsUp from the exact sum over the whole unit domain

-- | Upper bound on the sum of two polynomials (see 'chplAdd').
p1 +^ p2 = upper
    where
    (_, upper, _) = chplAdd p1 p2
-- | Lower bound on the sum of two polynomials (see 'chplAdd').
p1 +. p2 = lower
    where
    (lower, _, _) = chplAdd p1 p2
-- | Upper bound on the difference; negating @p2@ is exact.
p1 -^ p2 = p1 +^ chplNeg p2
-- | Lower bound on the difference; negating @p2@ is exact.
p1 -. p2 = p1 +. chplNeg p2

{-|
    Multiply two polynomials, rounding downwards and upwards.

    When either factor is a constant polynomial the work is delegated to
    'chplScale'. Otherwise every pair of terms is expanded using the
    Chebyshev product rule, applied per variable:

    > T_m * T_n = (T_(m+n) + T_|m-n|) / 2

    A pair of terms mentioning @v@ distinct variables therefore yields
    @2^v@ terms, each weighted by @1/2^v@. Zero powers are dropped from
    the resulting keys.

    Each resulting coefficient is accumulated twice:

    * once from the upwards rounded product;
    * once from the upwards rounded product of the negated factor, i.e. the
      downwards rounded product, negated.

    The sum of the gaps between the two, @roundOffCompensation@, bounds how
    far the upward-rounded polynomial can be from the exact product on the
    unit domain, where non-constant terms take values in [-1,1]. It is
    therefore added to / subtracted from the constant term.

    Both bounds share the upward-rounded coefficients and differ only in
    the constant term. Their difference thus agrees with the returned width
    @2 * roundOffCompensation@, up to the rounding of that one constant.
-}
chplMultiply ::
    (B.ERRealBase b, DomainBox box varid Int, Ord box) =>
    ERChebPoly box b ->
    ERChebPoly box b ->
    (ERChebPoly box b, ERChebPoly box b, b)
        {-^ lower and upper bounds on the product and an upper bound on their difference -}
chplMultiply p1@(ERChebPoly coeffs1) p2@(ERChebPoly coeffs2) =
    case (chplGetConst p1, chplGetConst p2) of
        (Just c1, _) -> chplScale c1 p2
        (_, Just c2) -> chplScale c2 p1
        _ ->
            (ERChebPoly prodCoeffsDown, ERChebPoly prodCoeffsUp, 2 * roundOffCompensation)
    where
    -- both bounds are derived from the same upward-rounded coefficients
    prodCoeffsUp =
        Map.insertWith plusUp chplConstTermKey roundOffCompensation
            directProdCoeffsUp
    prodCoeffsDown =
        Map.insertWith plusDown chplConstTermKey (- roundOffCompensation)
            directProdCoeffsUp
    -- sum over all terms of (up-rounded coeff + negated down-rounded coeff),
    -- i.e. of the per-coefficient rounding gaps, rounded upwards
    roundOffCompensation =
        foldr plusUp 0 $ Map.elems $
            Map.unionWith plusUp directProdCoeffsUp directProdCoeffsDownNeg
    (directProdCoeffsUp, directProdCoeffsDownNeg) =
        foldl addCombiCoeff (Map.empty, Map.empty) combinedCoeffs
        where
        addCombiCoeff
                (prevCoeffsUp, prevCoeffsDownNeg)
                (coeffUp, coeffDownNeg, (powersList, coeffCount)) =
            foldl addOnce (prevCoeffsUp, prevCoeffsDownNeg) powersList
            where
            -- every one of the coeffCount expanded terms gets an equal share
            addOnce (accUp, accDownNeg) powers =
                (Map.insertWith plusUp powers coeffUpFrac accUp,
                 Map.insertWith plusUp powers coeffDownNegFrac accDownNeg)
            coeffUpFrac = coeffUp / coeffCountB
            coeffDownNegFrac = coeffDownNeg / coeffCountB
            coeffCountB = fromInteger coeffCount
    combinedCoeffs =
        [   -- (list of triples)
            (
                (c1 * c2) -- upwards rounded product
            ,
                ((- c1) * c2) -- downwards rounded negated product
            ,
                combinePowers powers1 powers2
            )
        |
            (powers1, c1) <- coeffs1List,
            (powers2, c2) <- coeffs2List
        ]
    -- expand the product of two multivariate Chebyshev terms into the list
    -- of resulting terms together with how many of them there are
    combinePowers powers1 powers2 =
        (combinedPowers, 2 ^ (length sumsDiffs))
        where
        combinedPowers =
            map (DBox.fromAscList . (filter $ \ (_, v) -> v > 0)) $
                allPairsCombinations $
                    sumsDiffs
        sumsDiffs =
            -- associative list with the sum and difference of powers for each variable
            zipWith (\(k,s) (_,d) -> (k,(s,d)))
                (DBox.toAscList $ DBox.unionWith (\a b -> (a + b)) powers1 powers2)
                (DBox.toAscList $ DBox.unionWith (\a b -> abs (a - b)) powers1 powers2)
    coeffs1List =
        Map.toList coeffs1
    coeffs2List =
        Map.toList coeffs2

-- | Upper bound on the product of two polynomials (see 'chplMultiply').
p1 *^ p2 =
    case chplMultiply p1 p2 of
        (_, prodUp, _) -> prodUp
-- | Lower bound on the product of two polynomials (see 'chplMultiply').
p1 *. p2 =
    case chplMultiply p1 p2 of
        (prodDown, _, _) -> prodDown

{-|
    Multiply a polynomial by a scalar, rounding downwards and upwards.

    A constant polynomial is scaled directly, both ways. Otherwise every
    coefficient is scaled with upward rounding. The gaps between the
    upward and downward scaled coefficients are summed, and that total is
    added to / subtracted from the constant term to form the upper / lower
    bound.
-}
chplScale ::
    (B.ERRealBase b, DomainBox box varid Int, Ord box) =>
    b ->
    (ERChebPoly box b) ->
    (ERChebPoly box b, ERChebPoly box b, b)
        {-^ lower and upper bounds on the product and an upper bound on their difference -}
chplScale ratio p@(ERChebPoly coeffs) =
    maybe scaleGeneral scaleConst (chplGetConst p)
    where
    -- constant polynomial: scale its single value in both directions
    scaleConst c =
        let
            cUp = ratio `timesUp` c
            cDown = ratio `timesDown` c
        in
        (chplConst cDown, chplConst cUp, cUp - cDown)
    -- general polynomial: shared up-scaled coefficients, constant shifted
    scaleGeneral =
        (shiftConst plusDown (negate totalGap), shiftConst plusUp totalGap, 2 * totalGap)
    shiftConst op delta =
        ERChebPoly (Map.insertWith op chplConstTermKey delta scaledUp)
    (totalGap, scaledUp) =
        Map.mapAccum scaleTerm 0 coeffs
    scaleTerm gapSoFar coeff =
        (gapSoFar + (up - down), up)
        where
        up = ratio `timesUp` coeff
        down = ratio `timesDown` coeff

-- | Lower bound on @r@ times @p@ (see 'chplScale').
chplScaleDown r p =
    case chplScale r p of
        (prodDown, _, _) -> prodDown
-- | Upper bound on @r@ times @p@ (see 'chplScale').
chplScaleUp r p =
    case chplScale r p of
        (_, prodUp, _) -> prodUp

{-|
    Multiply a polynomial by itself, rounding downwards and upwards.

    This is 'chplMultiply' applied to two copies of the polynomial, with
    the width component of its result dropped.
-}
chplSquare ::
    (B.ERRealBase b, DomainBox box varid Int, Ord box) =>
    ERChebPoly box b ->
    (ERChebPoly box b, ERChebPoly box b)
        {-^ lower and upper bounds on the square -}
chplSquare p =
    let (sqDown, sqUp, _) = chplMultiply p p
    in (sqDown, sqUp)