-- | CSM classes of the (open) strata in the set of /ordered/ @n@-tuples,
-- that is, @Q^n = P^1 x P^1 x ... x P^1@
--
-- Of special interest is the open stratum of distinct points, 
-- since any other stratum can be computed from that stratum 
-- by a simple push-forward.
-- 
-- The open stratum of distinct points can be computed recursively,
-- since the full space @Q^n@ is the disjoint union of all strata
-- (indexed by /set partitions/).
--
-- But we also have a closed formula (in terms of the recursively defined
-- polynomials @Q_k(a,b)@), which makes the computation significantly faster.
--

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances,
             ScopedTypeVariables, Rank2Types, GADTs
  #-}

module Math.RootLoci.CSM.Equivariant.Ordered 
  ( -- * The product of projective lines @P^1 x ... x P^1@
    tangentChernClass
    -- * Diagonal embedding
  , j_star 
  , smallDiagonal
    -- * Recursive computation of the CSM of the strata
  , computeOpenStratumCSM     
  , computeAnyStratumCSM
  , computeClosureOfAnyStratumCSM
    -- * The structure lemma
  , QPow(..)
  , umbralDistinctFormula
  , umbralSubstQPow
  , computeQPolys
    -- * The recursive formula for the @Q_k(a,b)@ polynomials
  , formulaQPoly 
    -- * Formula for the CSM class of the stratum of distinct points
  , formulaDistinctCSM
  , formulaAnyStratumCSM
  ) 
  where

--------------------------------------------------------------------------------

import Data.Semigroup ( Semigroup(..) )

import Math.Combinat.Classes
import Math.Combinat.Numbers
import Math.Combinat.Sign
import Math.Combinat.Partitions.Integer ( Partition(..) )
import Math.Combinat.Partitions.Set
import Math.Combinat.Sets

import Math.RootLoci.Algebra
import Math.RootLoci.Geometry
import Math.RootLoci.Misc

import qualified Data.Set as Set ; import Data.Set (Set)

import qualified Math.RootLoci.Algebra.FreeMod as ZMod

import Math.RootLoci.CSM.Equivariant.PushForward

--------------------------------------------------------------------------------
-- * The product of projective lines @P^1 x ... x P^1@

-- | Equivariant total Chern class of the tangent bundle of
-- @P^1 x P^1 x ... x P^1@ (@n@ factors).
--
-- For a single projective line we have
--
-- > c(T(PV)) = prod_k (1 + w_k + omega)   `mod`   prod_k (w_k + omega)
--
-- so with weights @alpha@ and @beta@:
--
-- > (1+alpha+omega) * (1+beta+omega) = 1 + alpha + beta + 2*omega
--
-- because the quadratic term @(alpha+omega)*(beta+omega)@ is exactly the
-- relation we work modulo, hence zero. Taking the product over all factors:
--
-- > c(T(P^1 x P^1 ... x P^1)) = prod_i (1 + alpha + beta + 2*omega_i)
--
-- The Chern root or the Chern class version is picked by 'select2'
-- according to @base@.
tangentChernClass
  :: ChernBase base
  => Int                  -- ^ the number of projective lines
  -> ZMod (Omega base)    -- ^ the tangent Chern class of their product
tangentChernClass n = select2 (viaRoots, viaClasses) where
  viaRoots   = tangentChernClassAB    n
  viaClasses = tangentChernClassChern n

tangentChernClassAB
  :: Int                  -- ^ The number of projective lines
  -> ZMod (Omega AB)
tangentChernClassAB d = ZMod.product [ entry i | i<-[1..d] ] where
  entry i = ZMod.fromList
    [ (Omega []  (AB 0 0) , 1)
    , (Omega []  (AB 1 0) , 1)
    , (Omega []  (AB 0 1) , 1)
    , (Omega [i] (AB 0 0) , 2)      -- 2x !
    ]

-- | Chern class version of 'tangentChernClass': the product over
-- @i = 1..d@ of the factors @1 + c_1 + 2*omega_i@, where @c_1 = alpha + beta@.
tangentChernClassChern
  :: Int                  -- ^ The number of projective lines
  -> ZMod (Omega Chern)
tangentChernClassChern d = ZMod.product (map factor [1..d]) where
  factor i = ZMod.fromList
    [ (Omega []  (Chern 0 0) , 1)    -- 1
    , (Omega []  (Chern 1 0) , 1)    -- c_1
    , (Omega [i] (Chern 0 0) , 2)    -- 2*omega_i  (note the factor 2)
    ]

--------------------------------------------------------------------------------
-- * Diagonal embedding

-- | Push-forward along the diagonal embedding of ordered products of
-- @P^1@-s determined by the given groups of indices. The push-forward is
-- computed by 'delta_star'' and then converted back with 'unsafeEtaToOmega'.
j_star :: ChernBase base => [[Int]] -> ZMod (Omega base) -> ZMod (Omega base)
j_star indices cls = unsafeEtaToOmega (delta_star' indices cls)

-- | The CSM class of the small diagonal @{ (x,x,...,x) }@ in
-- @P^1 x ... x P^1@ (@n@ factors). This is the push-forward of
-- @c(T(P^1))@ along the embedding that identifies all @n@ coordinates.
smallDiagonal :: forall base. ChernBase base => Int -> ZMod (Omega base)
smallDiagonal n = j_star [allIndices] (tangentChernClass 1) where
  allIndices = [1..n]

--------------------------------------------------------------------------------
-- * CSM of the strata

-- | Recursively compute the CSM class of the Zariski-open set @U^d@ of
-- distinct ordered points in @Q^d = P^1 x ... x P^1@.
--
-- @Q^d@ is the disjoint union of @U^d@ and the strata indexed by the set
-- partitions of @{1..d}@ with fewer than @d@ parts (the fat diagonals).
-- Each of those strata is a push-forward of the open stratum for a smaller
-- @d@, so we subtract their CSM classes from the Chern class of @Q^d@.
--
-- NOTE: 'formulaDistinctCSM' is a more explicit formula for the same result
-- (which is /much/ faster to compute), so the two can be compared.
--
-- Forgetting the alpha\/beta part, the result should be
--
-- > (1-h1-h2-...-hd)^(d-3)
--
-- where @h_i^2 = 0@ for all @i@. Including @alpha@ and @beta@ we instead
-- have the umbral formula
--
-- > (q-h1-h2-...-hd)^(d-3)
--
-- followed by the umbral substitution @q^k -> Q_k@, where the polynomials
-- @Q_k(alpha,beta)@ are defined recursively for @k >= -3@.
--
computeOpenStratumCSM :: ChernBase base => Int -> ZMod (Omega base)
computeOpenStratumCSM = polyCache2 openCSM where

  openCSM :: forall b. ChernBase b => Int -> ZMod (Omega b)
  openCSM 0 = ZMod.one
  openCSM 1 = tangentChernClass 1
  openCSM d = tangentChernClass d `ZMod.sub` ZMod.sum (map computeAnyStratumCSM fatDiagonals)
    where
      -- every set partition except the one into singletons (the open stratum itself)
      fatDiagonals = filter (\p -> numberOfParts p < d) (setPartitions d)


-- | CSM class of the stratum indexed by a set partition. It is the
-- push-forward of the CSM class of the open stratum (one point per part)
-- along the diagonal map determined by the set partition.
--
computeAnyStratumCSM :: ChernBase base => SetPartition -> ZMod (Omega base)
computeAnyStratumCSM (SetPartition parts) =
  let nparts = length parts
  in  j_star parts (computeOpenStratumCSM nparts)

-- | CSM class of the closure of the stratum indexed by a set partition.
-- This is the sum of the CSM classes of all strata listed by
-- 'closureSetOfSetPartition'. CSM classes are additive, so the sum equals
-- the CSM class of the closure, assuming those strata partition it.
computeClosureOfAnyStratumCSM :: ChernBase base => SetPartition -> ZMod (Omega base)
computeClosureOfAnyStratumCSM =
  ZMod.sum . map computeAnyStratumCSM . Set.toList . closureSetOfSetPartition

--------------------------------------------------------------------------------
-- * The structure lemma

-- | A formal monomial @q^k@. The exponent @k@ may be negative, for example
-- @q^(-3)@.
--
-- Monomials multiply by adding exponents, and @q^0@ is the unit.
newtype QPow = QPow Int deriving (Eq,Ord,Show)

-- | Multiplication of monomials: the exponents add up.
-- Since GHC 8.4 (base-4.11) 'Semigroup' is a superclass of 'Monoid', so
-- this instance is required for the 'Monoid' instance to compile.
instance Semigroup QPow where
  QPow e <> QPow f = QPow (e + f)

-- | The unit monomial is @q^0@; 'mappend' is the canonical @(<>)@.
instance Monoid QPow where
  mempty  = QPow 0
  mappend = (<>)

-- | Renders the monomial @q^k@ with 'showVarPower', using the variable name @"q"@.
instance Pretty QPow where
  pretty qpow = case qpow of
    QPow expo -> showVarPower "q" expo

--------------------------------------------------------------------------------

-- | The umbral formula for the CSM class of the open stratum of distinct
-- ordered points:
--
-- > (q - u1 - u2 - ... - un)^(n-3)
--
-- where @u_i^2 = 0@. Because of this, the expansion in powers of @q@ has
-- only finitely many terms. For @n >= 3@ it is
--
-- > sum_{k=0}^{n-3} (-1)^(n-3-k) * (n-3)!/k! * sigma_{n-3-k}(u) * q^k
--
-- The formula also works for @n = 0,1,2@, where negative powers of @q@
-- appear. The expansions for @n = 3,2,1,0@ are:
--
-- > (q - u1 - u2 - u3)^0   =  q^0
-- > (q - u1 - u2     )^-1  =  1/q + u1/q^2 + u2/q^2 + (2*u1*u2)/q^3
-- > (q - u1          )^-2  =  1/q^2 + (2*u1)/q^3
-- > (q               )^-3  =  1/q^3
--
-- The powers @q^k@ are meant to be replaced by the polynomials @Q_k@ via
-- 'umbralSubstQPow'.
--
umbralDistinctFormula :: Int -> ZMod (Omega QPow)
umbralDistinctFormula n
  | n <  0    = error "umbralDistinctFormula: n should be nonnegative"
  | n == 0    = ZMod.generator $ monom [] (-3)
  | n == 1    = ZMod.fromList
                  [ (monom []    (-2) , 1)
                  , (monom [1]   (-3) , 2)
                  ]
  | n == 2    = ZMod.fromList
                  [ (monom []    (-1) , 1)
                  , (monom [1]   (-2) , 1)
                  , (monom [2]   (-2) , 1)
                  , (monom [1,2] (-3) , 2)
                  ]
  | otherwise = ZMod.sum
                  [ ZMod.scale coeff $ (ZMod.symPoly (n-3-k) us) * (ZMod.generator $ monom [] k)
                  | k <- [0..n-3]
                    -- (-1)^(n-3-k) * (n-3)!/k!  ;  n-3+k has the same parity as n-3-k
                  , let coeff = negateIfOdd (n-3+k) (factorial (n-3) `div` factorial k)
                  ]

  where
    -- the monomial  (prod_{i in xs} u_i) * q^k
    monom xs k = Omega xs (QPow k)
    -- the variables u_1, ..., u_n
    us = [ monom [i] 0 | i <- [1..n] ]

-- | Umbral substitution: every formal power @q^k@ is replaced by the
-- polynomial @subst1 (QPow k)@. The @u@-part (the index list of 'Omega')
-- of each term is kept, and the coefficients are multiplied.
umbralSubstQPow :: (ChernBase base) => (QPow -> ZMod base) -> ZMod (Omega QPow) -> ZMod (Omega base)
umbralSubstQPow subst1 input = ZMod.sum (map substTerm (ZMod.toList input)) where
  -- expand a single term  coeff * u_{us} * q^k
  substTerm (Omega us qpow, coeff) = ZMod.fromList
    [ (Omega us ab, c*coeff) | (ab, c) <- ZMod.toList (subst1 qpow) ]

--------------------------------------------------------------------------------

-- | The polynomials @Q_k(a,b)@, extracted from the recursively computed
-- CSM classes of the open strata.
--
-- It is not hard to prove, by considering the push-forward along the map
-- forgetting one of the points, that the CSM class of the locus @U^n@ of
-- distinct points has the following form (for @n>=3@):
--
-- > csm(U^n) = sum_{k=0}^{n-3} \frac{(n-3)!}{k!} (-1)^{n-3-k} \sigma_{n-3-k}(u) Q_k(a,b)
--
-- This is 'umbralDistinctFormula' with @q^k@ replaced by @Q_k@. In
-- @csm(U^(n+3))@ the polynomial @Q_n@ appears with coefficient 1 and no
-- @u@ factor, and every other term only involves @Q_k@ with @k < n@. So
-- @Q_n@ is what remains after subtracting those already known terms from
-- 'computeOpenStratumCSM' @(n+3)@.
--
-- The results can be compared with 'formulaQPoly', the recursive formula
-- for the polynomials, which is /much/ faster to evaluate.
--
computeQPolys :: Int -> ZMod AB
computeQPolys = icache' ZMod.zero (-3) qpolyFromCSM where

  qpolyFromCSM :: Int -> ZMod AB
  qpolyFromCSM n
    | n <  -3    = error "computeQPolys: n >= -3 is required"
    | n == -3    = ZMod.one
    | otherwise  = ZMod.mapBase dropOmega remainder
    where
      -- the leading term q^n of the umbral formula (coefficient 1, no u-s)
      leadingTerm = ZMod.generator (Omega [] (QPow n))

      -- all the other terms, with q^k already replaced by the known Q_k (k < n)
      lowerTerms  = umbralSubstQPow (\(QPow k) -> computeQPolys k)
                      (umbralDistinctFormula (n+3) - leadingTerm)

{-
      -- an older way to compute 'lowerTerms' (written for n >= 0):
      smaller = ZMod.sum 
        [ ZMod.scale coeff $ 
            (ZMod.symPoly (n-k) us) * (embed $ computeQPolys k)
        | k<-[0..n-1]
        , let coeff = negateIfOdd (n+k) (factorial n `div` factorial k)
        ]
      us = [ Omega [i] (AB 0 0) | i<-[1..n+3] ]
      embed = ZMod.mapBase $ \ab -> Omega [] ab
-}

      -- NOTE: the (rather slow) recursive CSM computation is the basis here
      remainder   = computeOpenStratumCSM (n+3) - lowerTerms

      -- the remainder must not contain any u-s; otherwise something went wrong
      dropOmega (Omega us ab) = case us of
        [] -> ab
        _  -> error $ "computeQPolys: cannot project u terms:\n  " ++ pretty remainder

--------------------------------------------------------------------------------
-- * The recursive formula for the @Q_k(a,b)@ polynomials

-- | The Fibonacci-type recursive formula for the @Q_k(a,b)@ polynomials
--
-- > Q_{-3} = 1
-- > Q_k    = Q_{k-1} * (1 - (k+1)*(a+b)) - Q_{k-2} * a*b * (k-1)*(k+2)
-- >        = Q_{k-1} * (1 - (k+1)* c_1 ) - Q_{k-2} * c_2 * (k-1)*(k+2)
--
-- Here @Q_k = 0@ for @k < -3@. The Chern root and the Chern class versions
-- are both available through this one function; 'select1' picks the one
-- matching @base@.
formulaQPoly :: ChernBase base => Int -> ZMod base
formulaQPoly n = select1 (rootVersion, classVersion) where
  rootVersion  = formulaQPolyAB    n
  classVersion = formulaQPolyChern n

-- | Chern root version of 'formulaQPoly'. Values are cached through 'icache''.
formulaQPolyAB :: Int -> ZMod AB
formulaQPolyAB = icache' ZMod.zero (-3) step where

  step :: Int -> ZMod AB
  step k
    | k <  -3   = ZMod.zero
    | k == -3   = ZMod.konst 1
    | otherwise = let Pair c1 c2 = qpolyRecursionCoeffs k    -- Q_k = c1*Q_{k-1} + c2*Q_{k-2}
                  in  c1 * formulaQPolyAB (k-1) + c2 * formulaQPolyAB (k-2)

-- | Chern class version of the @Q_k@ formula (see 'formulaQPoly'). It should
-- be faster than the Chern root version, because there are fewer terms.
-- Values are cached through 'icache''.
formulaQPolyChern :: Int -> ZMod Chern
formulaQPolyChern = icache' ZMod.zero (-3) calcQPoly where
  
  calcQPoly :: Int -> ZMod Chern
  calcQPoly n
    | n <  -3   = ZMod.zero          -- Q_k = 0 below the base case
    | n == -3   = ZMod.konst 1       -- Q_{-3} = 1
    | otherwise = mult1 * prev1 + mult2 * prev2
    where
      prev1 = formulaQPolyChern (n-1)
      prev2 = formulaQPolyChern (n-2)

      -- Q_n = mult1 * Q_{n-1} + mult2 * Q_{n-2}
      Pair mult1 mult2 = qpolyRecursionCoeffs n

-- | The coefficients @Pair m1 m2@ of the recursion
-- @Q_k = m1 * Q_{k-1} + m2 * Q_{k-2}@:
--
-- > m1 = 1 - (k+1)*(a+b)       = 1 - (k+1)*c_1
-- > m2 = - (k-1)*(k+2) * a*b   = - (k-1)*(k+2) * c_2
--
-- The first form is the Chern root version, the second the Chern class
-- version; 'select2' picks the one matching @base@.
qpolyRecursionCoeffs :: ChernBase base => Int -> Pair (ZMod base)
qpolyRecursionCoeffs n = select2 (Pair rootsM1 rootsM2, Pair classesM1 classesM2) where

  k :: Integer
  k = fromIntegral n

  linCoeff  = negate (k + 1)               -- coefficient of the degree 1 part
  quadCoeff = negate ((k - 1) * (k + 2))   -- coefficient of the degree 2 part

  rootsM1   = ZMod.fromList [ (AB 0 0, 1), (AB 1 0, linCoeff), (AB 0 1, linCoeff) ]
  rootsM2   = ZMod.singleton (AB 1 1) quadCoeff

  classesM1 = ZMod.fromList [ (Chern 0 0, 1), (Chern 1 0, linCoeff) ]
  classesM2 = ZMod.singleton (Chern 0 1) quadCoeff

--------------------------------------------------------------------------------
-- small @Q_k@ polynomials

{-
polyZMod :: ZMod AB -> (forall base. ChernBase base => ZMod base)
polyZMod ab = select1 (ab, abToChern ab)

-- | @Q_0 = ( 1 - a + b) ( 1 + a - b) = 1 - a^2 - b^2 + 2ab = 1 - c_1^2 + 4c_2@
konstQ0 :: ChernBase base => ZMod base
konstQ0 = polyZMod q0 where 
  q0 = ZMod.fromList [ ( AB 0 0 ,  1 )  , ( AB 2 0 , -1 )  , ( AB 0 2 , -1 )  , ( AB 1 1 ,  2 )  ]  

-- | @Q_-1 = 1 + a + b + 2 a*b = 1 + c_1 + 2c_2@
konstQminus1 :: ChernBase base => ZMod base
konstQminus1 = polyZMod qminus1 where
  qminus1 = ZMod.fromList [ ( AB 0 0 ,  1 ) ,  ( AB 1 0 ,  1 )  , ( AB 0 1 ,  1 )  , ( AB 1 1 ,  2 ) ]

-- | @Q_-2 = 1 + a + b = 1 + c_1@
konstQminus2 :: ChernBase base => ZMod base
konstQminus2 = polyZMod qminus2 where
  qminus2 = ZMod.fromList [ ( AB 0 0 ,  1 ) , ( AB 1 0 ,  1 ) , ( AB 0 1 ,  1 ) ]

-- | @Q_-3 = 1@
konstQminus3 :: ChernBase base => ZMod base
konstQminus3 = ZMod.konst 1
-}

--------------------------------------------------------------------------------
-- * Formula for the CSM class of the stratum of distinct points

-- | The CSM class of the set of distinct ordered points. It takes the
-- umbral formula 'umbralDistinctFormula' and substitutes the polynomials
-- 'formulaQPoly' in place of the powers @q^k@.
--
-- It should agree with 'computeOpenStratumCSM', but is much faster.
--
formulaDistinctCSM :: ChernBase base => Int -> ZMod (Omega base)
formulaDistinctCSM n
  | n >= 0    = umbralSubstQPow (\(QPow k) -> formulaQPoly k) (umbralDistinctFormula n)
  | otherwise = error "formulaDistinctCSM: dimension should be nonnegative"
{-
  -- earlier, direct version (kept for reference):
  | n < 3     = computeOpenStratumCSM n
  | otherwise = ZMod.sum 
      [ ZMod.scale coeff poly
      | k <- [0..n-3] 
      , let coeff = paritySignValue (n-3-k) * div (factorial (n-3)) (factorial k)
      , let qk    = formulaQPoly k
      , let sym   = choose (n-3-k) [1..n]
      , let poly  = ZMod.fromList [ (Omega xs ab, k) | xs <- sym, (ab,k) <- ZMod.toList qk ]
      ]
-}

-- | The CSM class of the stratum indexed by a set partition. It is the
-- push-forward of 'formulaDistinctCSM' (one point per part) along the
-- diagonal map @Delta_mu@ of the set partition.
formulaAnyStratumCSM :: ChernBase base => SetPartition -> ZMod (Omega base)
formulaAnyStratumCSM setp =
  unsafeEtaToOmega (delta_star setp (formulaDistinctCSM (numberOfParts setp)))
  
--------------------------------------------------------------------------------