{-# LANGUAGE PatternGuards #-}

-- | Monomials in a countably infinite set of variables x1, x2, x3, ...
--
--   A monomial is a coefficient together with a map from variable
--   indices to integer exponents.  The ordering defined on monomials is
--   tailored to their use in implementing cycle index series (see the
--   documentation of 'T').
module MathObj.Monomial
    ( -- * Type
      T(..)

      -- * Creating monomials
    , mkMonomial
    , constant
    , x

      -- * Utility functions
    , degree
    , pDegree
    , scaleMon

    ) where

import qualified Algebra.Additive as Additive
import qualified Algebra.Ring as Ring
import qualified Algebra.ZeroTestable as ZeroTestable
import qualified Algebra.Differential as Differential
import qualified Algebra.Field as Field

import qualified Data.Map as M
import Data.Ord (comparing)
import Control.Arrow ((***))
import Data.List (sort, intercalate)

import NumericPrelude
import PreludeBase

-- | A monomial is a map from variable indices to integer powers,
--   paired with a (polymorphic) coefficient.  Negative integer powers
--   are allowed, so every monomial with a nonzero coefficient has a
--   multiplicative inverse (see the Field instance).
--
--   Instances are provided for Eq, Ord, ZeroTestable, Additive, Ring,
--   Differential, and Field.  Eq and Ord look only at the exponent map
--   and ignore the coefficient.  Adding two monomials only makes sense
--   if they have matching variables and exponents.  The Differential
--   instance represents partial differentiation with respect to x1.
--
--   The Ord instance for monomials orders them first by partition
--   degree (see 'pDegree'), then by largest variable index (largest
--   first), then by exponent (largest first).  This may seem a bit
--   odd, but in fact reflects the use of these monomials to implement
--   cycle index series, where this ordering corresponds nicely to
--   generation of integer partitions. To make the library more general
--   we could parameterize monomials by the desired ordering.
data T a = Cons { coeff  :: a                    -- ^ the coefficient
                , powers :: M.Map Integer Integer -- ^ variable index -> exponent
                }

-- | Build a monomial from a coefficient and a list of
--   (variable index, exponent) pairs.  If an index occurs more than
--   once, the last pair for that index wins (the behaviour of
--   'M.fromList').
mkMonomial :: a -> [(Integer, Integer)] -> T a
mkMonomial c = Cons c . M.fromList

-- | Render a monomial for humans.  The rules are:
--
--   * a zero coefficient prints as @0@;
--   * a monomial with no variable of nonzero exponent prints as just
--     its coefficient;
--   * otherwise the variables are printed (e.g. @x1^2 x4@), preceded by
--     the coefficient, or by a bare @-@ when it is -1, or by nothing
--     when it is 1.
instance (ZeroTestable.C a, Ring.C a, Eq a, Show a) => Show (T a) where
  show (Cons a pows) | isZero a    = "0"
                     | noVars      = show a
                     | a == 1      = showVars pows
                     | a == (-1)   = "-" ++ showVars pows
                     | otherwise   = show a ++ " " ++ showVars pows
    where
      -- showVars skips exponent-0 entries, so a monomial whose exponents
      -- are all zero is really a constant; treating it as one avoids
      -- printing "" (coefficient 1) or a trailing space.
      noVars = M.null (M.filter (/= 0) pows)

-- | Render an exponent map as space-separated variables such as
--   @x1^2 x3@, in ascending variable order.  Variables with exponent 0
--   are omitted and an exponent of 1 is not written.
showVars :: M.Map Integer Integer -> String
showVars m = intercalate " " [ render v p | (v, p) <- M.toList m, p /= 0 ]
  where render v 1 = "x" ++ show v
        render v p = "x" ++ show v ++ "^" ++ show p

-- | The degree of a monomial is the sum of its exponents.  For example,
--   x1^3 x2^2 x4^3 has degree 3 + 2 + 3 = 8.
--
--   (Uses 'M.elems' rather than the deprecated @Data.Map.fold@, which
--   is no longer provided by recent versions of containers.)
degree :: T a -> Integer
degree = sum . M.elems . powers

-- | The \"partition degree\" of a monomial: each variable index is
--   multiplied by its exponent, and the products are summed.  For
--   example, x1^3 x2^2 x4^3 has partition degree 1*3 + 2*2 + 4*3 = 19.
--   The name comes from reading x1^3 x2^2 x4^3 as the integer
--   partition 1 + 1 + 1 + 2 + 2 + 4 + 4 + 4 of 19.
pDegree :: T a -> Integer
pDegree mon = sum [ v * p | (v, p) <- M.toList (powers mon) ]

-- | A monomial with the given coefficient and no variables.
constant :: a -> T a
constant = flip Cons M.empty

-- | The monomial consisting of the single variable x_n, with
--   coefficient one and exponent one.
x :: (Ring.C a) => Integer -> T a
x n = Cons { coeff = Ring.one, powers = M.singleton n 1 }

-- | Multiply every variable subscript by a constant, so x_i^p becomes
--   x_(n*i)^p; the coefficient and exponents are untouched.  Useful for
--   operations like plethystic substitution or Mobius inversion.
--
--   NOTE(review): n is presumably meant to be nonzero; with n == 0 all
--   subscripts collide and 'M.mapKeys' keeps only one of the exponents.
scaleMon :: Integer -> T a -> T a
scaleMon n mon = mon { powers = M.mapKeys (* n) (powers mon) }

-- | Two monomials are equal when they have the same exponent map; the
--   coefficients are not compared.
instance Eq (T a) where
  m1 == m2 = powers m1 == powers m2

-- | Order monomials by partition degree ('pDegree') first.  Ties are
--   broken by comparing the (variable, exponent) pairs starting from the
--   largest variable index, with larger pairs sorting first.  The
--   coefficient is ignored.
instance Ord (T a) where
  compare m1 m2
    | d1 < d2   = LT
    | d1 > d2   = GT
    | otherwise = comparing q m1 m2
    where d1 = pDegree m1
          d2 = pDegree m2
          -- M.assocs is already in ascending key order (keys are unique),
          -- so reversing it lists pairs by descending variable index; no
          -- extra sort is needed.  Rev makes larger pairs compare smaller.
          q  = map Rev . reverse . M.assocs . powers

-- | Wrapper whose ordering is the reverse of the wrapped type's.
newtype Rev a = Rev { getRev :: a }
  deriving Eq
instance Ord a => Ord (Rev a) where
  compare r1 r2 = compare (getRev r2) (getRev r1)

-- | A monomial is zero exactly when its coefficient is zero; the
--   exponents are not consulted.
instance (ZeroTestable.C a) => ZeroTestable.C (T a) where
  isZero = isZero . coeff
  
-- | Monomials are added by adding their coefficients.  The sum is only
--   meaningful when both operands have the same exponent map, or when
--   one of them has a zero coefficient (as 'zero' does).  A zero result
--   is normalised to have no variables.
instance (Additive.C a, ZeroTestable.C a) => Additive.C (T a) where
  zero = Cons zero M.empty
  negate (Cons a m) = Cons (negate a) m

  -- precondition: m1 == m2, unless one operand has a zero coefficient.
  -- A zero left operand must not discard the right operand's variables;
  -- otherwise @zero + x@ (e.g. the first step of a left fold that starts
  -- from 'zero') would lose x's exponents.
  (Cons a1 m1) + (Cons a2 m2) | isZero s  = Cons s M.empty
                              | isZero a1 = Cons s m2
                              | otherwise = Cons s m1
                              where s = a1 + a2

-- | Multiplication multiplies the coefficients and adds exponents
--   variable by variable; variables whose exponents cancel to zero are
--   dropped from the product.  Integer literals become constants.
instance (Ring.C a, ZeroTestable.C a) => Ring.C (T a) where
  fromInteger = flip Cons M.empty . fromInteger
  (Cons c1 e1) * (Cons c2 e2) =
      Cons (c1 * c2) (M.filter nonZero (M.unionWith (+) e1 e2))
    where nonZero = not . isZero

-- | Partial differentiation with respect to x1.
--
--   For a monomial @c * x1^p * rest@ with @p /= 0@ the result is
--   @(c*p) * x1^(p-1) * rest@: the other variables are kept, and the x1
--   entry is removed once its exponent drops to zero.  A monomial in
--   which x1 does not occur (or occurs with exponent 0) differentiates
--   to zero.
instance (ZeroTestable.C a, Ring.C a) => Differential.C (T a) where
  differentiate (Cons a m)
    | Just p <- M.lookup 1 m, p /= 0
                = Cons (a * fromInteger p) (M.update lowerExp 1 m)
    | otherwise = Cons 0 M.empty
    where
      -- decrement the exponent of x1, deleting the entry when it hits 0
      lowerExp 1 = Nothing
      lowerExp q = Just (q - 1)

-- | Monomials with a nonzero coefficient are invertible: the inverse
--   takes the reciprocal of the coefficient and negates every exponent.
--   Inverting a monomial with a zero coefficient is an error.
instance (ZeroTestable.C a, Field.C a, Eq a) => Field.C (T a) where
  recip (Cons a pows)
    | isZero a  = error "Monomial.recip: division by zero"
    | otherwise = Cons (recip a) (M.map negate pows)