-- | Polynomials in a countably infinite set of variables x1, x2, x3, ...
module MathObj.MultiVarPolynomial
    ( -- * Type
      T(..)

      -- * Constructing polynomials
    , fromMonomials
    , lift0
    , lift1
    , lift2
    , x
    , constant

      -- * Operations

    , compose

    , merge

    ) where

import qualified Algebra.Additive as Additive
import qualified Algebra.Ring as Ring
import qualified Algebra.ZeroTestable as ZeroTestable
import qualified Algebra.Differential as Differential

import qualified MathObj.Monomial as Mon

import qualified Data.Map as M

import NumericPrelude
import PreludeBase

-- | A polynomial is just a list of monomials, construed as their sum.
--   We maintain the invariant that polynomials are always sorted by
--   the ordering on monomials defined in "MathObj.Monomial": first by
--   partition degree, then by largest variable index (decreasing),
--   then by exponent of the highest-index variable (decreasing).
--   This works out nicely for operations on cycle index series.
--
--   The list may be infinite: addition, multiplication and 'compose'
--   in this module consume it lazily and rely on the sort order to do
--   so.  The zero polynomial is the empty list.
--
--   Note that 'Cons' (and 'fromMonomials') are exported and do not
--   check the invariant; callers must supply an already sorted list.
--
--   Instances are provided for Additive, Ring, Differential
--   (partial differentiation with respect to x1), and Show.
newtype T a = Cons [Mon.T a]

-- | Render a polynomial as a sum of its monomials.  The empty
--   polynomial prints as @0@; every monomial after the first is
--   joined with @ + @, or with @ - @ (followed by the negated
--   monomial) when its coefficient is negative.
instance (ZeroTestable.C a, Ring.C a, Ord a, Show a) => Show (T a) where
  show (Cons monos) =
    case monos of
      []           -> "0"
      (lead:trail) -> show lead ++ concatMap signed trail
    where
      signed mono
        | Mon.coeff mono < 0 = " - " ++ show (negate mono)
        | otherwise          = " + " ++ show mono

{-# INLINE fromMonomials #-}
-- | Wrap a list of monomials as a polynomial.  The list is taken as-is:
--   it must already be sorted according to the invariant on 'T'.
fromMonomials :: [Mon.T a] -> T a
fromMonomials monomials = Cons monomials

{-# INLINE lift0 #-}
-- | Lift a raw monomial list to a polynomial (no sorting is performed).
lift0 :: [Mon.T a] -> T a
lift0 monos = Cons monos

{-# INLINE lift1 #-}
-- | Lift a transformation on monomial lists to one on polynomials.
lift1 :: ([Mon.T a] -> [Mon.T a]) -> (T a -> T a)
lift1 f (Cons monos) = lift0 (f monos)

{-# INLINE lift2 #-}
-- | Lift a binary operation on monomial lists to one on polynomials.
lift2 :: ([Mon.T a] -> [Mon.T a] -> [Mon.T a]) -> (T a -> T a -> T a)
lift2 f p q =
  case (p, q) of
    (Cons ms, Cons ns) -> Cons (f ms ns)

-- | The polynomial consisting of the single monomial @Mon.x n@,
--   i.e. the variable x_n.
x :: (Ring.C a) => Integer -> T a
x = fromMonomials . (: []) . Mon.x

-- | The constant polynomial with the given coefficient, built from the
--   single monomial @Mon.constant c@.
constant :: a -> T a
constant = fromMonomials . (: []) . Mon.constant

-- | Add two sorted monomial lists: a union-style 'merge' in which
--   equal elements are summed.  Because only the heads are compared at
--   each step, this also works on infinite polynomials.
add :: (Ord a, Additive.C a) => [a] -> [a] -> [a]
add = merge True (+)

-- | Merge two sorted lists.
--
--   Elements comparing equal are fused with the combining function.
--   When the flag is 'True', elements found in only one of the lists
--   are kept (a union-like merge); when it is 'False' they are dropped,
--   so only fused elements survive (an intersection-like merge).
--   Only list heads are inspected at each step, so the result is
--   produced lazily.
merge :: Ord a => Bool -> (a -> a -> a) -> [a] -> [a] -> [a]
merge keepSingles combine =
  if keepSingles
    then mergeWith (:) id
    else mergeWith (\_ rest -> rest) (\_ -> [])
  where
    -- 'emit' decides what happens to an unmatched element;
    -- 'leftover' decides what happens once one list is exhausted.
    mergeWith emit leftover = go
      where
        go [] rights = leftover rights
        go lefts [] = leftover lefts
        go lefts@(l:ls) rights@(r:rs)
          | l < r     = emit l (go ls rights)
          | l > r     = emit r (go lefts rs)
          | otherwise = combine l r : go ls rs

-- | Polynomials form an additive group: zero is the empty monomial
--   list, negation negates every monomial, and addition merges the two
--   sorted lists with 'add'.
instance (Additive.C a, ZeroTestable.C a) => Additive.C (T a) where
  zero              = Cons []
  negate (Cons ms)  = Cons (map negate ms)
  Cons ms + Cons ns = Cons (add ms ns)

-- | Multiply two sorted lists, each read as the sum of its elements.
--   The product of the two heads is emitted first.  The remainder is
--   the head of the left list times the tail of the right list, merged
--   (via 'add') with the product of the left tail and the whole right
--   list.  Emitting the head immediately keeps this productive on
--   infinite inputs.
mul :: (Ring.C a, Ord a) => [a] -> [a] -> [a]
mul (p:ps) qqs@(q:qs) = p * q : add (map (p *) qs) (mul ps qqs)
mul _ _ = []

-- | Polynomials form a ring: an integer becomes a one-monomial
--   polynomial, and multiplication is 'mul' on the sorted lists.
instance (Ring.C a, ZeroTestable.C a) => Ring.C (T a) where
  fromInteger       = fromMonomials . (: []) . fromInteger
  Cons ms * Cons ns = Cons (mul ms ns)

-- Partial differentiation with respect to x1: each monomial is
-- differentiated separately and those that become zero are dropped.
instance (ZeroTestable.C a, Ring.C a) => Differential.C (T a) where
  differentiate (Cons monos) =
    Cons [ d | d <- map Differential.differentiate monos, not (isZero d) ]

-- | Plethystic substitution: F o G = F(G(x1,x2,x3...),
--   G(x2,x4,x6...), G(x3,x6,x9...), ...)  See Bergeron, Labelle, and
--   Leroux, \"Combinatorial Species and Tree-Like Structures\",
--   p. 43.
--
--   The inner polynomial G must not have a constant term with a
--   nonzero coefficient; otherwise this calls 'error'.
--
--   When G is the zero polynomial, every variable is replaced by zero,
--   so only the constant term of F survives.  Polynomials are sorted by
--   partition degree first, so a constant term (if any) is the head of
--   F.  Only that head is inspected, which keeps this case terminating
--   for infinite F.
compose :: (Ring.C a, ZeroTestable.C a) => T a -> T a -> T a
compose (Cons []) _ = Cons []
compose (Cons (m:_)) (Cons [])
  -- Previously the head was returned unconditionally, so e.g.
  -- x1 `compose` 0 wrongly gave x1 instead of 0.
  | Mon.degree m == 0 = Cons [m]
  | otherwise         = Cons []
compose (Cons xs) yys@(Cons (y:_))
  | Mon.degree y == 0 && (not . isZero . Mon.coeff $ y)
    = error $ "MultiVarPolynomial.compose: inner series must not have a constant term."
  | otherwise = comp xs yys

-- | Plethystically substitute the polynomial @p@ into the sum of the
--   (sorted) monomials @ms@, lazily.
--
--   We need to be careful to make sure this is suitably
--   lazy. For example, this works for finite polynomials:
--
-- > comp ms p = sum . map (substMon p) $ ms
--
--   but not for infinite ones!
--
--   This is accomplished by calling a recursive helper function
--   taking as an extra argument a running sum containing only terms
--   with partition degree greater than or equal to the most recently
--   processed monomial.  Plethystically substituting a polynomial
--   (with no constant term) into a monomial of partition degree d
--   produces a polynomial with all terms of partition degree >= d, so
--   when we encounter a monomial with partition degree d, we know we
--   are done with all terms in the running sum of lesser partition
--   degree.
--
--   The running sum starts as the empty polynomial ('Additive.zero').
--   Seeding it with the literal @0@ would mean @fromInteger 0@.  The
--   'Ring.C' instance turns that into the one-element list
--   @[fromInteger 0]@, not the empty list, so a (presumably
--   zero-coefficient) constant monomial would be emitted as an extra
--   term of the result.
--
--   Precondition: the second argument has no constant term.
comp :: (Ring.C a, ZeroTestable.C a) => [Mon.T a] -> T a -> T a
comp ms p = comp' Additive.zero ms
  where -- comp' :: T a -> [Mon.T a] -> T a
        comp' part []      = part
        -- Emit the finished low-degree prefix of the running sum before
        -- recursing, so output is produced incrementally.
        comp' part (m:ms') = lift2 (++) done $ comp' (rest + substMon p m) ms'
          where (done,rest) = splitPoly ((< Mon.pDegree m) . Mon.pDegree) part

-- | Plethystic substitution of a polynomial into a monomial.
--
--   Each key/value pair @(n, e)@ of @'Mon.powers' m@ contributes the
--   factor @(scalePoly n poly) ^ e@.  The product of these factors
--   (starting from 1) is then scaled by the monomial's coefficient.
--
--   Uses 'M.foldrWithKey'.  The previously used @M.foldWithKey@ was a
--   deprecated alias for it and has been removed from recent
--   @containers@ releases.
substMon :: (ZeroTestable.C a, Ring.C a) => T a -> Mon.T a -> T a
substMon poly m
  = (constant (Mon.coeff m) *)
  . M.foldrWithKey (\sub pow -> (*) (scalePoly sub poly ^ pow)) 1
  $ Mon.powers m

-- | @scalePoly n Z@ changes Z(x_1, x_2, x_3, ...) into
--   Z(x_n, x_2n, x_3n, ...), by applying 'Mon.scaleMon' to every
--   monomial.
scalePoly :: Integer -> T a -> T a
scalePoly n (Cons monos) = Cons [ Mon.scaleMon n mono | mono <- monos ]

-- | Split a polynomial at the first monomial failing the predicate:
--   the longest prefix satisfying it, and everything after.  The split
--   is lazy, so it is safe on infinite polynomials.
splitPoly :: (Mon.T a -> Bool) -> T a -> (T a, T a)
splitPoly keep (Cons monos) = (Cons (fst halves), Cons (snd halves))
  where halves = span keep monos