-- | The umbral formula for the open CSM classes.
--
-- The formula is the following:
--
-- > A(mu)    = 1 / aut(mu) * prod_i Theta(mu_i)
-- > Theta(p) = ( (1 + beta*s) (alpha+t)^p - (1 + alpha*s) (beta+t)^p ) / ( alpha - beta )
--
-- and the umbral substitution producing the CSM class (at least for @length(mu)>=3@) is:
--
-- > t^j  ->  P_j(m)
-- > s^k  ->  (n-3)(n-3-1)...(n-3-k+1) * Q(n-3-k)
--
-- where @n = length(mu)@ and @m = weight(mu)@.
--
-- Note that Theta(p) is actually a polynomial in @alpha@ and @beta@ (symmetric in them);
-- furthermore, it is linear in @s@ and of degree @p@ in @t@.

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances, ScopedTypeVariables #-}
module Math.RootLoci.CSM.Equivariant.Umbral where

--------------------------------------------------------------------------------

import Math.Combinat.Classes
import Math.Combinat.Numbers
import Math.Combinat.Partitions.Integer

import Data.Array.IArray
import Data.Semigroup ( Semigroup(..) )

import qualified Data.Set as Set

import Math.RootLoci.Algebra
import Math.RootLoci.Geometry
import Math.RootLoci.Misc

import qualified Math.RootLoci.Algebra.FreeMod as ZMod

import Math.RootLoci.CSM.Equivariant.PushForward ( tau , piStarTableAff , piStarTableProj )
import Math.RootLoci.CSM.Equivariant.Ordered     ( formulaQPoly )

import qualified Math.RootLoci.CSM.Equivariant.Direct as Direct

--------------------------------------------------------------------------------
-- * The umbral variables

-- | A monomial @s^k * t^j@ in the two umbral variables, stored as its pair of
-- exponents: @ST k j@ has exponent @k@ for @s@ and exponent @j@ for @t@.
data ST 
  = ST !Int !Int
  deriving (Eq,Ord,Show)

-- | Multiplication of monomials: the exponents add componentwise.
--
-- Since GHC 8.4 (base-4.11) 'Semigroup' is a superclass of 'Monoid', so this
-- instance is required for the 'Monoid' instance below to compile.
instance Semigroup ST where
  ST s1 t1 <> ST s2 t2 = ST (s1+s2) (t1+t2)

-- | The unit is the constant monomial @s^0 * t^0 = 1@.
instance Monoid ST where
  mempty  = ST 0 0
  -- canonical definition; keeps older GHCs (where 'mappend' has no default) happy
  mappend = (<>)

-- | Renders @ST e f@ as @s^e*t^f@ using 'showVarPower'. A factor whose exponent
-- is zero is left out, and the unit monomial @ST 0 0@ renders as the empty string.
instance Pretty ST where
  pretty (ST 0 0) = ""
  pretty (ST e 0) = showVarPower "s" e
  pretty (ST 0 f) = showVarPower "t" f
  pretty (ST e f) = showVarPower "s" e ++ "*" ++ showVarPower "t" f

-- | Pretty-prints a polynomial in the umbral monomials 'ST' whose coefficients
-- are themselves free-module elements. Each coefficient is rendered with
-- 'pretty' and wrapped by 'paren'; the monomials use the 'Pretty' 'ST' instance.
prettyMixedST :: forall b c. (Pretty b, Num c, Eq c, IsSigned c, Show c) => FreeMod (FreeMod c b) ST -> String
prettyMixedST = prettyFreeMod'' showCoeff pretty
  where
    -- the signature mentions the scoped type variables @b@, @c@ of the outer signature
    showCoeff :: FreeMod c b -> String
    showCoeff coeff = paren (pretty coeff)

--------------------------------------------------------------------------------
-- * The umbral formula

-- | @Theta(p)@, the umbral factor attached to a part of size @p@:
--
-- > Theta(p) = ( (1 + beta*s) (alpha+t)^p - (1 + alpha*s) (beta+t)^p ) / ( alpha - beta )
--
-- This is a polynomial in @alpha@, @beta@, @s@, @t@, symmetric in @alpha@ and
-- @beta@ and linear in @s@. The code builds it term by term:
--
-- > Theta(p) =  sum_{i=0}^{p-1} binom(p,i) * tau(p-i-1) * t^i
-- >          +  sum_{i=0}^{p-2} binom(p,i) * (alpha*beta) * tau(p-i-2) * s*t^i
-- >          -  s*t^p
--
-- There is no @s*t^(p-1)@ term. Here @alpha*beta@ is the monomial
-- @select0 (alphaBeta,c2)@ (presumably @c2@ when working in the Chern basis).
-- Calls 'error' when @p < 1@.
theta :: ChernBase base => Int -> FreeMod (ZMod base) ST
theta p
  | p < 1     = error "theta: non-positive input"
  | otherwise = ZMod.fromList (sFreeTerms ++ sLinearTerms ++ [topTerm])
  where
    -- coefficients of t^i, for 0 <= i <= p-1
    sFreeTerms =
      [ (ST 0 i, ZMod.scale (binomial p i) (tau (p-i-1)))
      | i <- [0 .. p-1]
      ]

    -- coefficients of s*t^i, for 0 <= i <= p-2
    sLinearTerms =
      [ (ST 1 i, ZMod.scale (binomial p i) (ZMod.mulMonom abMonom (tau (p-i-2))))
      | i <- [0 .. p-2]
      ]

    -- coefficient of s*t^p
    topTerm = (ST 1 p, ZMod.konst (-1))

    -- the monomial alpha*beta
    abMonom = select0 (alphaBeta, c2)

-- | Same as 'theta', but every integer coefficient is converted to a rational
-- one with 'fromIntegral'.
thetaQ :: ChernBase b => Int -> FreeMod (QMod b) ST
thetaQ = ZMod.mapCoeff (ZMod.mapCoeff fromIntegral) . theta

-------------------------------------------------------------------------------- 

-- | The integral umbral formula @prod_i Theta(mu_i)@: the product of 'theta'
-- over all parts of the partition. It is not divided by @aut(mu)@.
integralUmbralFormula :: ChernBase base => Partition -> FreeMod (ZMod base) ST 
integralUmbralFormula (Partition parts) = ZMod.product (map theta parts)

-- | The umbral formula @1/aut(mu) * prod_i Theta(mu_i)@ with rational
-- coefficients. It is the product of 'thetaQ' over the parts of @mu@, with
-- every coefficient scaled by @1 / aut(mu)@.
umbralFormula :: ChernBase base => Partition -> FreeMod (QMod base) ST 
umbralFormula mu@(Partition ps) =
  ZMod.mapCoeff (ZMod.scale invAut) (ZMod.product (map thetaQ ps))
  where
    invAut :: Rational
    invAut = 1 / fromIntegral (aut mu)

--------------------------------------------------------------------------------
-- * The affine CSM

-- | The polynomial substituted in place of the umbral monomial @s^k*t^j@
-- (affine version):
--
-- > s^k*t^j  ->  P_j(m) * Q(n-3-k) * (n-3)_k
--
-- Here @n = length(mu)@ and @m = weight(mu)@. The falling factorial is
-- @(n-3)_k = (n-3)(n-4)...(n-2-k)@. @P_j(m)@ is entry @j@ of 'piStarTableAff' @m@,
-- and @Q(i)@ is 'formulaQPoly' @i@. Monomials outside @-3 <= k <= n-3@,
-- @0 <= j <= m@ are sent to zero.
--
-- The function is meant to be partially applied to the partition. The table of
-- @P@ polynomials is bound once per partition and reused for every monomial.
--
-- NOTE(review): @k@ is never negative for inputs built from 'theta', so the
-- lower bound @-3@ never restricts anything. It may have been meant as @0@ — confirm.
umbralSubstPolyAff :: ChernBase base => Partition -> ST -> ZMod base
umbralSubstPolyAff part = substitute where

  numParts = numberOfParts part
  wt       = weight part
  pTable   = piStarTableAff wt

  inRange k j = k >= -3 && k <= numParts-3 && j >= 0 && j <= wt

  substitute (ST k j)
    | not (inRange k j) = ZMod.zero
    | otherwise         = ZMod.scale fallingFact (qPoly `ZMod.mul` pPoly)
    where
      -- (n-3)(n-4)...(n-3-(k-1)), computed in Integer
      fallingFact :: Integer
      fallingFact = product (map (\i -> fromIntegral (numParts-3-i)) [0 .. k-1])

      qPoly = formulaQPoly (numParts-3-k)
      pPoly = pTable ! j

-- | The (affine) umbral substitution. Each monomial @s^k*t^j@ of the input is
-- replaced by its 'umbralSubstPolyAff' image and multiplied by its coefficient.
-- The results are summed.
umbralSubstitutionAff :: (ChernBase base) => Partition -> FreeMod (ZMod base) ST -> ZMod base
umbralSubstitutionAff part input = ZMod.sum (map substTerm (ZMod.toList input))
  where
    -- partially applied once, so whatever it precomputes from @part@ is shared by all terms
    substMonom = umbralSubstPolyAff part
    substTerm (st, coeff) = coeff `ZMod.mul` substMonom st

-- | CSM class of the open stratum of @mu@ (affine version), computed from the
-- umbral formula.
--
-- For @length(mu) >= 3@ the result is @1/aut(mu)@ times 'umbralSubstitutionAff'
-- applied to 'integralUmbralFormula'. For shorter partitions the umbral formula
-- is not used. Instead the result is 'Direct.directOpenCSM' with 'forgetGamma'
-- applied. The computation is wrapped by 'polyCache1' (presumably a cache keyed
-- on the partition — not visible here).
umbralAffOpenCSM :: ChernBase base => Partition -> ZMod base   
umbralAffOpenCSM = polyCache1 calc where

  -- NOTE(review): the umbral formula is only used for @n >= 3@; the original
  -- comment questioned whether this restriction is actually necessary.
  calc :: ChernBase base => Partition -> ZMod base
  calc mu
    | numberOfParts mu < 3 = forgetGamma (Direct.directOpenCSM mu)
    | otherwise            = ZMod.invScale (aut mu)
                               (umbralSubstitutionAff mu (integralUmbralFormula mu))

-- | CSM class of the closure of the stratum of @part@ (affine version). It is
-- the sum of 'umbralAffOpenCSM' over every partition in @closureSet part@,
-- wrapped by 'polyCache1' like the open version.
umbralAffClosedCSM :: ChernBase base => Partition -> ZMod base   
umbralAffClosedCSM = polyCache1 calc where

  calc :: ChernBase base => Partition -> ZMod base
  calc = ZMod.sum . map umbralAffOpenCSM . Set.toList . closureSet

--------------------------------------------------------------------------------
-- * The projective CSM

-- | The polynomial substituted in place of the umbral monomial @s^k*t^j@
-- (projective version):
--
-- > s^k*t^j  ->  P_j(m) * Q(n-3-k) * (n-3)_k
--
-- Here @n = length(mu)@ and @m = weight(mu)@. The falling factorial is
-- @(n-3)_k = (n-3)(n-4)...(n-2-k)@. @P_j(m)@ is entry @j@ of 'piStarTableProj' @m@.
-- @Q(i)@ is 'formulaQPoly' @i@, lifted into @ZMod (Gam base)@ by 'injectZMod'.
-- Monomials outside @-3 <= k <= n-3@, @0 <= j <= m@ are sent to zero.
--
-- The function is meant to be partially applied to the partition. The table of
-- @P@ polynomials is bound once per partition and reused for every monomial.
--
-- NOTE(review): @k@ is never negative for inputs built from 'theta', so the
-- lower bound @-3@ never restricts anything. It may have been meant as @0@ — confirm.
umbralSubstPolyProj :: forall base. ChernBase base => Partition -> ST -> ZMod (Gam base)
umbralSubstPolyProj part = substitute where

  numParts = numberOfParts part
  wt       = weight part
  pTable   = piStarTableProj wt

  inRange k j = k >= -3 && k <= numParts-3 && j >= 0 && j <= wt

  substitute (ST k j)
    | not (inRange k j) = ZMod.zero
    | otherwise         = ZMod.scale fallingFact (qPoly `ZMod.mul` pPoly)
    where
      -- (n-3)(n-4)...(n-3-(k-1)), computed in Integer
      fallingFact :: Integer
      fallingFact = product (map (\i -> fromIntegral (numParts-3-i)) [0 .. k-1])

      qPoly = injectZMod (formulaQPoly (numParts-3-k)) :: ZMod (Gam base)
      pPoly = pTable ! j                                :: ZMod (Gam base)


-- | The (projective) umbral substitution. Each monomial @s^k*t^j@ of the input is
-- replaced by its 'umbralSubstPolyProj' image and multiplied by its coefficient,
-- which is lifted with 'injectZMod'. The results are summed.
umbralSubstitutionProj :: (ChernBase base) => Partition -> FreeMod (ZMod base) ST -> ZMod (Gam base)
umbralSubstitutionProj part input = ZMod.sum (map substTerm (ZMod.toList input))
  where
    -- partially applied once, so whatever it precomputes from @part@ is shared by all terms
    substMonom = umbralSubstPolyProj part
    substTerm (st, coeff) = injectZMod coeff `ZMod.mul` substMonom st

-- | CSM class of the open stratum of @mu@ (projective version), computed from
-- the umbral formula.
--
-- For @length(mu) >= 3@ the result is @1/aut(mu)@ times 'umbralSubstitutionProj'
-- applied to 'integralUmbralFormula'. For shorter partitions the umbral formula
-- is not used, and the result is 'Direct.directOpenCSM'. The computation is
-- wrapped by 'polyCache2' (presumably a cache keyed on the partition — not
-- visible here).
umbralOpenCSM :: ChernBase base => Partition -> ZMod (Gam base)
umbralOpenCSM = polyCache2 calc where

  -- NOTE(review): the umbral formula is only used for @n >= 3@; the original
  -- comment questioned whether this restriction is actually necessary.
  calc :: ChernBase base => Partition -> ZMod (Gam base)
  calc mu
    | numberOfParts mu < 3 = Direct.directOpenCSM mu
    | otherwise            = ZMod.invScale (aut mu)
                               (umbralSubstitutionProj mu (integralUmbralFormula mu))

-- | CSM class of the closure of the stratum of @part@ (projective version). It
-- is the sum of 'umbralOpenCSM' over every partition in @closureSet part@,
-- wrapped by 'polyCache2' like the open version.
umbralClosedCSM :: ChernBase base => Partition -> ZMod (Gam base)
umbralClosedCSM = polyCache2 calc where

  calc :: ChernBase base => Partition -> ZMod (Gam base)
  calc = ZMod.sum . map umbralOpenCSM . Set.toList . closureSet

--------------------------------------------------------------------------------