{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}

{- |
Power series, either finite or unbounded.  (zipWith does exactly the
right thing to make it work almost transparently.)
-}

module MathObj.PowerSeries where

import qualified MathObj.Polynomial     as Poly

import qualified Algebra.Differential   as Differential
import qualified Algebra.IntegralDomain as Integral
import qualified Algebra.VectorSpace    as VectorSpace
import qualified Algebra.Module         as Module
import qualified Algebra.Vector         as Vector
import qualified Algebra.Transcendental as Transcendental
import qualified Algebra.Algebraic      as Algebraic
import qualified Algebra.Field          as Field
import qualified Algebra.Ring           as Ring
import qualified Algebra.Additive       as Additive
import qualified Algebra.ZeroTestable   as ZeroTestable

import Algebra.Module((*>))
import Algebra.ZeroTestable(isZero)

import qualified Data.List.Match as Match
import qualified NumericPrelude as NP
import qualified PreludeBase as P

import PreludeBase    hiding (const)
import NumericPrelude hiding (negate, stdUnit, divMod,
                              sqrt, exp, log,
                              sin, cos, tan, asin, acos, atan)


{- |
A power series represented by its coefficient list, absolute term first.
The list may be finite or infinite.

NOTE(review): 'Ord' is derived, i.e. it compares the coefficient lists
lexicographically, whereas 'Eq' delegates to 'Poly.equal'.
If 'Poly.equal' treats trailing zeros as insignificant,
the two instances disagree (e.g. on @[1]@ vs. @[1,0]@) - confirm.
-}
newtype T a =
   Cons {coeffs :: [a]}
   deriving (Ord)

-- | Build a series from its coefficient list (absolute term first).
-- Synonym for 'lift0'.
{-# INLINE fromCoeffs #-}
fromCoeffs :: [a] -> T a
fromCoeffs xs = lift0 xs

-- | Wrap a plain coefficient list as a series.
{-# INLINE lift0 #-}
lift0 :: [a] -> T a
lift0 xs = Cons xs

-- | Lift a unary operation on coefficient lists to series.
{-# INLINE lift1 #-}
lift1 :: ([a] -> [a]) -> (T a -> T a)
lift1 f = Cons . f . coeffs

-- | Lift a binary operation on coefficient lists to series.
{-# INLINE lift2 #-}
lift2 :: ([a] -> [a] -> [a]) -> (T a -> T a -> T a)
lift2 f xs ys = Cons (f (coeffs xs) (coeffs ys))

-- | The constant series whose only coefficient is the absolute term @x@.
{-# INLINE const #-}
const :: a -> T a
const x = Cons [x]

{-
The Functor instance is useful e.g. for showing power series over residue rings:
@fmap (ResidueClass.concrete 7) (fromCoeffs [1,4,4::ResidueClass.T Integer] * fromCoeffs [1,5,6])@
-}

-- Applies a function to every coefficient of the series.
instance Functor T where
  fmap f = Cons . map f . coeffs

-- | Precedence of function application,
-- used by the 'Show' instance to decide when to add parentheses.
{-# INLINE appPrec #-}
appPrec :: Int
appPrec = 10

-- Shows a series as the expression @PowerSeries.fromCoeffs [c0,c1,...]@.
-- That expression is a function application, so, following the
-- Haskell Report convention, it is parenthesised only when the
-- surrounding precedence is strictly greater than 'appPrec'
-- (e.g. when the series is itself an argument of a constructor).
instance (Show a) => Show (T a) where
  showsPrec p (Cons xs) =
    showParen (p > appPrec) (showString "PowerSeries.fromCoeffs " . shows xs)


-- | Keep only the first @n@ coefficients of a series.
{-# INLINE truncate #-}
truncate :: Int -> T a -> T a
truncate n (Cons xs) = Cons (take n xs)

{- |
Evaluate a (truncated) power series, given by its coefficient list,
at a point, using 'Poly.horner'.
See 'approx' for the partial sums of an infinite series.
-}
{-# INLINE eval #-}
eval :: Ring.C a => [a] -> a -> a
eval xs t = Poly.horner t xs

-- | 'eval' applied to the coefficients of a wrapped series.
{-# INLINE evaluate #-}
evaluate :: Ring.C a => T a -> a -> a
evaluate = eval . coeffs

{- |
Evaluate a (truncated) power series whose coefficients are vectors
at a scalar point, using 'Poly.hornerCoeffVector'.
-}
{-# INLINE evalCoeffVector #-}
evalCoeffVector :: Module.C a v => [v] -> a -> v
evalCoeffVector vs t = Poly.hornerCoeffVector t vs

-- | 'evalCoeffVector' applied to the coefficients of a wrapped series.
{-# INLINE evaluateCoeffVector #-}
evaluateCoeffVector :: Module.C a v => T v -> a -> v
evaluateCoeffVector = evalCoeffVector . coeffs


-- | Evaluate a (truncated) power series with scalar coefficients at an
-- argument from a ring @v@ that is also a module over the scalars,
-- using 'Poly.hornerArgVector'.
{-# INLINE evalArgVector #-}
evalArgVector :: (Module.C a v, Ring.C v) => [a] -> v -> v
evalArgVector cs v = Poly.hornerArgVector v cs

-- | 'evalArgVector' applied to the coefficients of a wrapped series.
{-# INLINE evaluateArgVector #-}
evaluateArgVector :: (Module.C a v, Ring.C v) => T a -> v -> v
evaluateArgVector = evalArgVector . coeffs

{- |
All partial sums of the series evaluated at @x@:
the result starts with @zero@, and its @k@-th element (counting from 0)
is the sum of the first @k@ terms @x^i * c_i@.
For an infinite coefficient list this is an infinite list
of successively better approximations.
-}
{-# INLINE approx #-}
approx :: Ring.C a => [a] -> a -> [a]
approx y x = scanl (+) zero terms
   where powers = iterate (x*) 1
         terms  = zipWith (*) powers y

-- | 'approx' applied to the coefficients of a wrapped series.
{-# INLINE approximate #-}
approximate :: Ring.C a => T a -> a -> [a]
approximate = approx . coeffs


{- |
Like 'approx', but for vector coefficients and a scalar point:
the partial sums of the terms @x^i *> c_i@, starting with @zero@.
-}
{-# INLINE approxCoeffVector #-}
approxCoeffVector :: Module.C a v => [v] -> a -> [v]
approxCoeffVector y x = scanl (+) zero terms
   where powers = iterate (x*) 1
         terms  = zipWith (*>) powers y

-- | 'approxCoeffVector' applied to the coefficients of a wrapped series.
{-# INLINE approximateCoeffVector #-}
approximateCoeffVector :: Module.C a v => T v -> a -> [v]
approximateCoeffVector = approxCoeffVector . coeffs


{- |
Like 'approx', but for scalar coefficients and a vector-valued point:
the partial sums of the terms @c_i *> x^i@, starting with @zero@.
-}
{-# INLINE approxArgVector #-}
approxArgVector :: (Module.C a v, Ring.C v) => [a] -> v -> [v]
approxArgVector y x = scanl (+) zero terms
   where powers = iterate (x*) 1
         terms  = zipWith (*>) y powers

-- | 'approxArgVector' applied to the coefficients of a wrapped series.
{-# INLINE approximateArgVector #-}
approximateArgVector :: (Module.C a v, Ring.C v) => T a -> v -> [v]
approximateArgVector = approxArgVector . coeffs


{- * Simple series manipulation -}

{- |
For the series of a real function @f@
compute the series for @\\x -> f (-x)@,
i.e. negate every coefficient at an odd position.
-}
alternate :: Additive.C a => [a] -> [a]
alternate = zipWith id signs
   where signs = cycle [id, NP.negate]

{- |
For the series of a real function @f@
compute the series for @\\x -> (f x + f (-x)) \/ 2@,
i.e. replace every coefficient at an odd position by zero.
-}
holes2 :: Additive.C a => [a] -> [a]
holes2 = zipWith id mask
   where mask = cycle [id, P.const zero]

{- |
For the series of a real function @f@
compute the real series for @\\x -> (f (i*x) + f (-i*x)) \/ 2@:
coefficients at positions 0, 4, 8, ... are kept,
those at positions 2, 6, 10, ... are negated, and all odd ones become zero.
-}
holes2alternate :: Additive.C a => [a] -> [a]
holes2alternate = zipWith id mask
   where mask = cycle [id, P.const zero, NP.negate, P.const zero]


{- * Series arithmetic -}

-- | Coefficient-wise sum of two series; delegates to 'Poly.add'.
add :: (Additive.C a) => [a] -> [a] -> [a]
add xs ys = Poly.add xs ys

-- | Coefficient-wise difference of two series; delegates to 'Poly.sub'.
sub :: (Additive.C a) => [a] -> [a] -> [a]
sub xs ys = Poly.sub xs ys

-- | Negation of a series; delegates to 'Poly.negate'.
negate :: (Additive.C a) => [a] -> [a]
negate xs = Poly.negate xs

-- | Scale a series by a constant factor; delegates to 'Poly.scale'.
scale :: Ring.C a => a -> [a] -> [a]
scale s xs = Poly.scale s xs

-- | Product of two series; delegates to 'Poly.mul'.
mul :: Ring.C a => [a] -> [a] -> [a]
mul xs ys = Poly.mul xs ys

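{-
Note that the comparison instances ('Eq' below and the derived 'Ord')
only make sense for finite series:
two infinite series that agree in every coefficient
can never be confirmed equal in finite time.
-}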

-- Equality of the coefficient lists as decided by 'Poly.equal'.
instance (Eq a, ZeroTestable.C a) => Eq (T a) where
    x == y = Poly.equal (coeffs x) (coeffs y)

-- Additive structure of series, lifted from the list operations in "Poly";
-- the empty coefficient list is the zero series.
instance (Additive.C a) => Additive.C (T a) where
    zero              = Cons []
    negate (Cons xs)  = Cons (Poly.negate xs)
    Cons xs + Cons ys = Cons (Poly.add xs ys)
    Cons xs - Cons ys = Cons (Poly.sub xs ys)

-- Ring structure of series: integer constants and 'one' are
-- single-coefficient series; multiplication uses 'mul'.
instance (Ring.C a) => Ring.C (T a) where
    one               = Cons [one]
    fromInteger n     = Cons [fromInteger n]
    Cons xs * Cons ys = Cons (mul xs ys)

-- Vector operations of series: zero and addition come from the
-- 'Additive.C' instance, scaling is the generic functor-based scaling.
instance Vector.C T where
   zero  = Additive.zero
   (<+>) = (Additive.+)
   (*>)  = Vector.functorScale

-- Scaling a series scales its coefficient list
-- (using the module structure of lists of @b@).
instance (Module.C a b) => Module.C a (T b) where
    s *> Cons vs = Cons (s *> vs)

-- Over a field of scalars, series form a vector space.
-- No methods are defined here; the 'Module.C' instance for 'T' supplies them.
instance (Module.C a b, Field.C a) => VectorSpace.C a (T b)

-- | Drop the leading positions at which /both/ lists have a zero coefficient.
-- Stops at the first position where either coefficient is non-zero
-- or either list has ended, and returns the two remaining lists.
stripLeadZero :: (ZeroTestable.C a) => [a] -> [a] -> ([a],[a])
stripLeadZero (x:xs) (y:ys)
   | isZero x && isZero y = stripLeadZero xs ys
stripLeadZero xs ys = (xs,ys)

{- |
Divide two series where the absolute term of the divisor is non-zero.
That is, power series with leading non-zero terms are the units
in the ring of power series.

The quotient @q@ of @n@ by @d@ satisfies @n = q*d@, hence coefficient-wise
@q = (n0 : (ns - q*ds)) \/ d0@; this is evaluated lazily
by referring to @q@ in its own definition.
An empty dividend gives an empty quotient;
a non-empty dividend and an empty divisor raise an error.

Knuth: Seminumerical algorithms
-}
divide :: (Field.C a) => [a] -> [a] -> [a]
divide (n0:ns) (d0:ds) = qs
   where qs = map (/ d0) (n0 : sub ns (mul qs ds))
divide [] _ = []
divide _ [] = error "PowerSeries.divide: division by empty series"

{- |
Divide two series also if the divisor has leading zeros.
Leading positions where both dividend and divisor are zero are dropped
first (see 'stripLeadZero').  If the divisor is then empty
or still starts with zero, an error is raised.
-}
divideStripZero :: (ZeroTestable.C a, Field.C a) => [a] -> [a] -> [a]
divideStripZero x' y' =
   case stripLeadZero x' y' of
      (x0, y0@(d0:_)) | not (isZero d0) -> divide x0 y0
      _ -> error "PowerSeries.divideStripZero: Division by zero."


-- Division of series via 'divide' (divisor must have a non-zero absolute term).
instance (Field.C a) => Field.C (T a) where
  Cons xs / Cons ys = Cons (divide xs ys)


{- |
Division with remainder, allowing leading zeros in the divisor.
The leading zeros of the divisor are split off; that many leading
coefficients of the dividend form the remainder, and the rest of the
dividend is divided by the rest of the divisor with 'divide'.
Returns @(quotient, remainder)@.
-}
divMod :: (ZeroTestable.C a, Field.C a) => [a] -> [a] -> ([a],[a])
divMod xs ys = (divide xRem yRem, xMod)
   where (yZero, yRem) = span isZero ys
         (xMod,  xRem) = Match.splitAt yZero xs

-- Division with remainder of wrapped series, via the list-level 'divMod'.
instance (ZeroTestable.C a, Field.C a) => Integral.C (T a) where
  divMod (Cons x) (Cons y) = (Cons q, Cons r)
     where (q, r) = divMod x y


-- | The sequence provided by 'Poly.progression'
-- (its reciprocals are 'recipProgression').
progression :: Ring.C a => [a]
progression = Poly.progression

-- | Reciprocals of the elements of 'progression'.
recipProgression :: (Field.C a) => [a]
recipProgression = [recip k | k <- progression]

-- | Term-wise derivative of a series; delegates to 'Poly.differentiate'.
differentiate :: (Ring.C a) => [a] -> [a]
differentiate xs = Poly.differentiate xs

-- | @integrate c xs@: term-wise antiderivative of @xs@ with integration
-- constant @c@; delegates to 'Poly.integrate'.
integrate :: (Field.C a) => a -> [a] -> [a]
integrate c xs = Poly.integrate c xs

-- Differentiation of wrapped series via the list-level 'differentiate'.
instance (Ring.C a) => Differential.C (T a) where
  differentiate (Cons xs) = Cons (differentiate xs)


{- |
Square root of a series.
Only the first coefficient needs a scalar square root, supplied as the
first argument; all further coefficients follow by field operations.
That is, if the square root of the first term is rational,
then all terms of the resulting series are rational.
The empty series yields the empty series.
-}

sqrt :: Field.C a => (a -> a) -> [a] -> [a]
sqrt _     []      = []
sqrt sqrt0 (x0:xs) = r0 : rs
   where r0 = sqrt0 x0
         -- from (r0 + t*R)^2 = x0 + t*X it follows R = (X - t*R^2) / (2*r0)
         rs = map (/ (r0 + r0)) (xs - (0 : mul rs rs))

{-
pow alpha t = t^alpha
(pow alpha . x)' = alpha * (pow (alpha-1) . x) * x'
alpha * (pow alpha . x) * x' = x * (pow alpha . x)'
y = pow alpha . x
alpha * y * x' = x * y'
y' = alpha * y * x' / x
-}

{- |
@pow f0 expon x@ computes the series of @\\t -> x(t) ^ expon@.

@f0@ maps the absolute term of @x@ to the absolute term of the result
(i.e. @f0 c = c ^ expon@), so only that one value needs a root computation.
All further coefficients follow from the differential equation

> y' = expon * y * x' / x

Input series must start with non-zero term.
-}
pow :: (Field.C a) => (a -> a) -> a -> [a] -> [a]
pow f0 expon x =
   let y  = integrate (f0 (head x)) y'
       -- y' = expon * y * (x'/x); 'derivedLog' computes x'/x.
       -- y is defined lazily in terms of its own derivative.
       y' = scale expon (mul y (derivedLog x))
   in  y

-- Roots of series: the scalar root operations of @a@ are only used
-- for the absolute term (see 'sqrt' and 'pow').
instance (Algebraic.C a) => Algebraic.C (T a) where
   sqrt (Cons xs) = Cons (sqrt Algebraic.sqrt xs)
   Cons xs ^/ r   = Cons (pow (\c -> c Algebraic.^/ r) (fromRational' r) xs)

{- |
Only the first term needs a transcendental computation; the others do not.
That's why a function computing the first term from the absolute term
of the argument is passed in.  The remaining coefficients follow from

> (exp . x)' =   (exp . x) * x'
> (sin . x)' =   (cos . x) * x'
> (cos . x)' = - (sin . x) * x'
-}

exp :: Field.C a => (a -> a) -> [a] -> [a]
exp f0 x = y
   where -- y is defined lazily in terms of its own derivative y * x'
         y = integrate (f0 (head x)) (mul y (differentiate x))

-- | Sine and cosine series of @x@ at once, as the two are mutually defined.
-- The argument function yields the sine and cosine of the absolute term of @x@.
sinCos :: Field.C a => (a -> (a,a)) -> [a] -> ([a],[a])
sinCos f0 x = (ySin, yCos)
   where (s0, c0) = f0 (head x)
         dx       = differentiate x
         ySin     = integrate s0 (mul yCos dx)
         yCos     = integrate c0 (negate (mul ySin dx))

-- | Scalar sine and cosine of the same argument, as expected by 'sinCos'.
sinCosScalar :: Transcendental.C a => a -> (a,a)
sinCosScalar t = (Transcendental.sin t, Transcendental.cos t)

-- | The sine resp. cosine component of 'sinCos'.
sin, cos :: Field.C a => (a -> (a,a)) -> [a] -> [a]
sin f0 x = fst (sinCos f0 x)
cos f0 x = snd (sinCos f0 x)

-- | Tangent series as the quotient of the sine and cosine series.
tan :: (Field.C a) => (a -> (a,a)) -> [a] -> [a]
tan f0 x = divide s c
   where (s, c) = sinCos f0 x

{-
(log x)'  == x'/x
(asin x)' == x'/sqrt(1-x^2)
(acos x)' == -x'/sqrt(1-x^2)
(atan x)' == x'/(1+x^2)
-}

{- |
Logarithm of a series, by integrating @x'\/x@.
The argument function computes the absolute term of the result
from the absolute term of @x@.
Input series must start with non-zero term.
-}
log :: (Field.C a) => (a -> a) -> [a] -> [a]
log f0 x = integrate c (derivedLog x)
   where c = f0 (head x)

{- |
Computes @(log x)'@, that is @x'\/x@.
The divisor @x@ must have a non-zero absolute term (see 'divide').
-}
derivedLog :: (Field.C a) => [a] -> [a]
derivedLog x = differentiate x `divide` x

-- | Arctangent of a series, by integrating @x'\/(1+x^2)@;
-- the argument function computes the absolute term of the result.
atan :: (Field.C a) => (a -> a) -> [a] -> [a]
atan f0 x = integrate (f0 (head x)) (divide dx denom)
   where dx    = differentiate x
         denom = [1] + mul x x

{- |
@asin sqrt0 f0 x@ and @acos sqrt0 f0 x@ compute the series of
@\\t -> asin (x t)@ and @\\t -> acos (x t)@, respectively.
@sqrt0@ is the scalar square root used for the absolute term of
@sqrt (1 - x^2)@; @f0@ computes the absolute term of the result from the
absolute term of @x@.  They integrate

> (asin . x)' =   x' / sqrt (1 - x^2)
> (acos . x)' = - x' / sqrt (1 - x^2)

(@acos@ previously reused @asin@'s derivative, which has the wrong sign.)
-}
asin, acos :: (Field.C a) =>
   (a -> a) -> (a -> a) -> [a] -> [a]
asin sqrt0 f0 x =
   let x' = differentiate x
   in  integrate (f0 (head x))
                 (divide x' (sqrt sqrt0 ([1] - mul x x)))
acos sqrt0 f0 x =
   let x' = differentiate x
   in  integrate (f0 (head x))
                 (negate (divide x' (sqrt sqrt0 ([1] - mul x x))))




-- Transcendental functions of series: the scalar functions of @a@ are
-- only needed for the absolute term; the list-level functions of this
-- module compute the remaining coefficients.
instance (Transcendental.C a) => Transcendental.C (T a) where
   pi             = Cons [NP.pi]
   exp  (Cons xs) = Cons (exp Transcendental.exp xs)
   sin  (Cons xs) = Cons (sin sinCosScalar xs)
   cos  (Cons xs) = Cons (cos sinCosScalar xs)
   tan  (Cons xs) = Cons (tan sinCosScalar xs)
   -- Multiplying in this order is especially fast when y is a singleton.
   x ** y         = Transcendental.exp (Transcendental.log x * y)
   log  (Cons xs) = Cons (log Transcendental.log xs)
   asin (Cons xs) = Cons (asin Algebraic.sqrt Transcendental.asin xs)
   acos (Cons xs) = Cons (acos Algebraic.sqrt Transcendental.acos xs)
   atan (Cons xs) = Cons (atan Transcendental.atan xs)

{- |
Composition of power series: @compose x y@ is the series of
@\\t -> x (y t)@.  It fulfills

> evaluate x . evaluate y == evaluate (compose x y)

The inner series @y@ must not have a (non-zero) absolute term,
otherwise an error is raised.  An empty @y@ yields just the
absolute term of @x@.
-}
compose :: (Ring.C a, ZeroTestable.C a) => T a -> T a -> T a
compose (Cons [])     (Cons []) = Cons []
compose (Cons (x0:_)) (Cons []) = Cons [x0]
compose (Cons outer)  (Cons (y0:inner))
   | isZero y0 = Cons (comp outer inner)
   | otherwise =
        error "PowerSeries.compose: inner series must not have an absolute term."

{- |
@comp xs y@ composes the outer coefficient list @xs@ with an inner series
whose (zero) absolute term has been omitted, i.e. @y@ holds the inner
coefficients from @t^1@ on.  This is Horner's scheme: prepending each
outer coefficient supplies the factor @t@ missing from @y@.
-}
comp :: (Ring.C a) => [a] -> [a] -> [a]
comp []     _ = []
comp (x:xs) y = x : mul y (comp xs y)


{- |
Compose two power series where the outer series
can be developed for any expansion point.
The function argument returns the outer series expanded around a given
point; it is expanded around the absolute term of the inner series, and
the rest of the inner series is then composed with 'comp'.
For an empty inner series, the outer series expanded around 0 is returned.
-}
composeTaylor :: Ring.C a => (a -> [a]) -> [a] -> [a]
composeTaylor x ys =
   case ys of
      []        -> x 0
      (y0:rest) -> comp (x y0) rest



{-
(x . y) = id
(x' . y) * y' = 1
y' = 1 / (x' . y)
-}

{- |
Series of the inverse function, returned in the form
(point of the expansion, power series).
The expansion point is the absolute term of the input @x@.
The series @y@ has absolute term zero and satisfies @y' = 1 \/ (x' . y)@.

This is exceptionally slow and needs cubic run-time.
-}

inv :: (Field.C a) => [a] -> (a, [a])
inv x = (head x, y)
   where
      -- the first term is zero, which is required for composition,
      -- so 'tail y' is exactly what 'comp' expects
      y  = integrate 0 y'
      y' = divide [1] (comp (differentiate x) (tail y))