{-# OPTIONS -fno-implicit-prelude -fglasgow-exts #-}

{- |
Power series in two variables (bivariate power series).
-}

module MathObj.PowerSeries2 where

import qualified MathObj.PowerSeries    as PS
import qualified MathObj.Polynomial     as Poly

import qualified Algebra.Differential   as Differential
import qualified Algebra.Vector         as Vector
import qualified Algebra.Algebraic      as Algebraic
import qualified Algebra.Field          as Field
import qualified Algebra.Ring           as Ring
import qualified Algebra.Additive       as Additive
import qualified Algebra.ZeroTestable   as ZeroTestable

import qualified NumericPrelude as NP
import qualified PreludeBase as P

import Data.List (isPrefixOf)
import NumericPrelude.List (compareLength)

import PreludeBase    hiding (const)
import NumericPrelude hiding (negate, stdUnit,
                              sqrt, exp, log,
                              sin, cos, tan, asin, acos, atan)

{- |
In order to handle both variables symmetrically,
we maintain, for each total degree, the list of coefficients of all terms of that degree.
That is, the @k@-th sub-list (counting from 0) has @k+1@ entries,
ordered by increasing power of @y@:

> eval [[a], [b,c], [d,e,f]] (x,y) ==
>    a + b*x+c*y + d*x^2+e*x*y+f*y^2

Although the sub-lists are always finite and are thus more like polynomials than power series,
division and square root computation are easier to implement when they are treated as power series.
-}
newtype T a = Cons {coeffs :: Core a} deriving (Ord)
-- NOTE(review): the derived 'Ord' compares the raw nested lists, whereas the
-- hand-written 'Eq' instance uses 'Poly.equal'. If 'Poly.equal' identifies
-- representations that differ only in trailing zeros, then
-- @compare x y == EQ@ and @x == y@ can disagree -- confirm before relying
-- on 'Ord' (e.g. as a Map key).

-- | Raw representation of a bivariate series: the @k@-th sub-list
-- (counting from 0) holds the @k+1@ coefficients of total degree @k@.
-- 'isValid' checks this shape.
type Core a = [[a]]

-- | Check the triangular shape of a raw coefficient list: the @k@-th
-- sub-list (counting from 0) must have exactly @k+1@ elements.
--
-- Returns 'False' at the first sub-list with a wrong length. For an
-- infinite list whose sub-lists all have the right lengths, it never
-- returns.
isValid :: [[a]] -> Bool
isValid xs = lengths `isPrefixOf` [1 ..]
   where lengths = map length xs

-- | Return the input unchanged, but validate each sub-list lazily.
--
-- When the @k@-th sub-list is forced, its length is compared against
-- @k+1@ (using 'compareLength' against a template list of units).
-- On a mismatch it raises an 'error' instead of yielding the sub-list.
check :: [[a]] -> [[a]]
check xs = zipWith validate templates xs
   where
      -- [()], [(),()], [(),(),()], ... : expected length of each sub-list
      templates = iterate (() :) [()]
      validate template row =
         case compareLength template row of
            EQ -> row
            _  -> error "PowerSeries2.check: invalid length of sub-list"

-- | Build a series from raw coefficients. Each sub-list's length is
-- validated lazily by 'check'.
fromCoeffs :: [[a]] -> T a
fromCoeffs xs = Cons (check xs)

-- | Embed a univariate series as a series in the first variable @x@.
--
-- Coefficient @k@ becomes the first entry of sub-list @k@ (the @x^k@
-- term). It is followed by @k@ zeros for the mixed and pure-@y@ terms.
fromPowerSeries0 :: Ring.C a => PS.T a -> T a
fromPowerSeries0 x =
   let zeroRuns = iterate (0:) []
   in  fromCoeffs (zipWith (:) (PS.coeffs x) zeroRuns)

-- | Embed a univariate series as a series in the second variable @y@.
--
-- Coefficient @k@ becomes the last entry of sub-list @k@ (the @y^k@
-- term). It is preceded by @k@ zeros.
fromPowerSeries1 :: Ring.C a => PS.T a -> T a
fromPowerSeries1 x =
   fromCoeffs (zipWith appendCoeff (iterate (0:) []) (PS.coeffs x))
   where appendCoeff zeros c = zeros ++ [c]

-- | Wrap raw coefficients without validating them.
lift0 :: Core a -> T a
lift0 xs = Cons xs

-- | Apply a function on raw coefficients to a wrapped series.
lift1 :: (Core a -> Core a) -> (T a -> T a)
lift1 f = Cons . f . coeffs

-- | Apply a binary function on raw coefficients to two wrapped series.
lift2 :: (Core a -> Core a -> Core a) -> (T a -> T a -> T a)
lift2 f x y = Cons (f (coeffs x) (coeffs y))

-- | Convert a list of univariate series (one per total degree) into raw
-- coefficients.
lift0fromPowerSeries :: [PS.T a] -> Core a
lift0fromPowerSeries xs = map PS.coeffs xs

-- | Lift an operation on lists of univariate series to raw coefficients.
-- Each sub-list is viewed as a 'PS.T' value.
lift1fromPowerSeries :: ([PS.T a] -> [PS.T a]) -> (Core a -> Core a)
lift1fromPowerSeries f = map PS.coeffs . f . map PS.fromCoeffs

-- | Binary counterpart of 'lift1fromPowerSeries'.
lift2fromPowerSeries :: ([PS.T a] -> [PS.T a] -> [PS.T a]) -> (Core a -> Core a -> Core a)
lift2fromPowerSeries f x0 x1 =
   let toSeries = map PS.fromCoeffs
   in  map PS.coeffs (f (toSeries x0) (toSeries x1))

-- | The constant series: a single degree-0 term.
const :: a -> T a
const x = Cons [[x]]

-- | Map a function over every coefficient, keeping the triangular shape.
instance Functor T where
  fmap f = Cons . map (map f) . coeffs

-- | Precedence of function application, used by the 'Show' instance.
appPrec :: Int
appPrec  = 10

-- | Render as an application of @PowerSeries2.fromCoeffs@ to the raw
-- coefficient list. The output is parenthesized when the surrounding
-- precedence is at least 'appPrec'.
instance (Show a) => Show (T a) where
  showsPrec p x =
    showParen (p >= appPrec) $
      showString "PowerSeries2.fromCoeffs " . shows (coeffs x)

{- * Series arithmetic -}

-- | Term-wise sum and difference of raw coefficients, delegated to the
-- list-level operations of "MathObj.PowerSeries".
add, sub :: (Additive.C a) => Core a -> Core a -> Core a
add x y = PS.add x y
sub x y = PS.sub x y

-- | Term-wise negation of raw coefficients.
negate :: (Additive.C a) => Core a -> Core a
negate x = PS.negate x

-- | Equality of the raw coefficient lists as decided by 'Poly.equal'.
instance (Eq a, ZeroTestable.C a) => Eq (T a) where
    x == y = Poly.equal (coeffs x) (coeffs y)

-- | Additive group structure. The zero series has no coefficients at all.
instance (Additive.C a) => Additive.C (T a) where
    zero            = Cons []
    negate (Cons x) = Cons (PS.negate x)
    Cons x + Cons y = Cons (PS.add x y)
    Cons x - Cons y = Cons (PS.sub x y)

-- | Multiply every coefficient by a scalar. The outer list is mapped and
-- each sub-list is scaled with 'Vector.*>'.
scale :: Ring.C a => a -> Core a -> Core a
scale s = map (s Vector.*>)

-- | Product of two series. The outer lists are multiplied as power series
-- whose coefficients are the per-degree sub-lists, viewed as 'PS.T'.
mul :: Ring.C a => Core a -> Core a -> Core a
mul x y = lift2fromPowerSeries PS.mul x y

-- | Ring structure: numeric constants become constant series, and
-- multiplication is series multiplication via 'mul'.
instance (Ring.C a) => Ring.C (T a) where
    one         = const one
    fromInteger = const . fromInteger
    x * y       = lift2 mul x y

-- | Vector-space operations. Zero and addition are the 'Additive.C'
-- ones (qualified here, so they cannot be mistaken for recursive method
-- references). Scaling goes through 'Vector.functorScale'.
instance Vector.C T where
   zero  = Additive.zero
   (<+>) = (Additive.+)
   (*>)  = Vector.functorScale

-- | Quotient of two series. The outer lists are divided as power series
-- whose coefficients are the per-degree sub-lists, viewed as 'PS.T'.
divide :: (Field.C a) =>
   Core a -> Core a -> Core a
divide x y = lift2fromPowerSeries PS.divide x y

-- | Division of series via 'divide'.
instance (Field.C a) => Field.C (T a) where
  Cons x / Cons y = Cons (divide x y)

-- | Square root of a series.
--
-- @fSqRt@ computes the square root of the constant term. The remaining
-- terms come from 'PS.sqrt' applied to the per-degree sub-lists, viewed
-- as 'PS.T' values.
--
-- The degree-0 sub-list of a valid series is a singleton. Any other shape
-- (e.g. raw data wrapped with 'lift0' without validation) used to fail
-- with an anonymous lambda pattern-match error. It now fails with a
-- descriptive 'error'.
sqrt :: (Field.C a) =>
   (a -> a) -> Core a -> Core a
sqrt fSqRt =
   lift1fromPowerSeries $ PS.sqrt (PS.const . sqrtConstant . PS.coeffs)
   where
      sqrtConstant [x] = fSqRt x
      sqrtConstant _   =
         error "PowerSeries2.sqrt: constant term must consist of exactly one coefficient"

-- | Square root of a series. The constant term uses the coefficient
-- type's 'Algebraic.sqrt' (see the module-level 'sqrt').
instance (Algebraic.C a) => Algebraic.C (T a) where
   sqrt (Cons x) = Cons (sqrt Algebraic.sqrt x)
-- Not yet implemented (kept from the original draft):
--   x ^/ y = lift1 (pow (Algebraic.^/ y)
--                       (fromRational' y)) x

-- | Exchange the roles of @x@ and @y@. Within each total degree the terms
-- are ordered by increasing power of @y@, so reversing every sub-list
-- swaps the two variables.
swapVariables :: Core a -> Core a
swapVariables xs = [reverse row | row <- xs]

-- | Partial derivative with respect to the first variable (@x@). It is
-- 'differentiate1' conjugated by 'swapVariables': swap, differentiate,
-- swap back.
differentiate0 :: (Ring.C a) => Core a -> Core a
differentiate0 x =
   swapVariables (differentiate1 (swapVariables x))

-- | Partial derivative with respect to the second variable (@y@).
--
-- Each sub-list of total degree @k@ is differentiated as a series in @y@.
-- That yields @k@ coefficients, i.e. a sub-list of total degree @k-1@.
-- The constant (degree-0) term vanishes, so it is dropped before
-- differentiating; every remaining sub-list then moves one position down.
-- Without this shift the result would start with an empty sub-list and
-- violate 'isValid'. Example:
--
-- > differentiate1 [[a], [b,c], [d,e,f]] == [[c], [e, 2*f]]
differentiate1 :: (Ring.C a) => Core a -> Core a
differentiate1 =
   lift1fromPowerSeries (map Differential.differentiate) . drop 1

-- | Integral with respect to the first variable (@x@). It is
-- 'integrate1' with the same integration constants, conjugated by
-- 'swapVariables'.
integrate0 :: (Field.C a) => [a] -> Core a -> Core a
integrate0 cs x =
   swapVariables (integrate1 cs (swapVariables x))

-- | Integral with respect to the second variable (@y@).
--
-- @cs@ holds the coefficients of the integration "constant", which is a
-- series in @x@ alone: @cs !! k@ becomes the coefficient of @x^k@.
--
-- Integrating the degree-@k@ sub-list raises its total degree to @k+1@.
-- The input is therefore shifted by one position: an empty sub-list is
-- prepended, which becomes the new degree-0 term @[cs !! 0]@. Previously
-- the degree-@k@ result stayed at position @k@ with @k+2@ entries and the
-- constant term was missing, which violated 'isValid'.
-- (Assumes @PS.integrate c s@ prepends @c@ to the term-wise integral of
-- @s@, so that @PS.integrate c [] == [c]@.)
--
-- Like 'zipWith', the result is truncated to the shorter of @cs@ and the
-- shifted input.
integrate1 :: (Field.C a) => [a] -> Core a -> Core a
integrate1 cs = zipWith PS.integrate cs . ([] :)

{- |
Since the inner series must start with a zero,
its first (constant) term is omitted in @y@.
-}
-- Composition via 'PS.comp' on series of 'PS.T' coefficients. The first
-- argument holds the outer univariate coefficients, each embedded as a
-- constant with 'PS.const'. The inner bivariate series must be passed
-- without its (zero) constant term.
-- NOTE(review): presumably computes outer(inner), following 'PS.comp' -- confirm.
comp :: (Ring.C a) => [a] -> Core a -> Core a
comp outer = lift1fromPowerSeries (PS.comp (map PS.const outer))