-- | Compute the non-equivariant Chern-Schwartz-MacPherson (CSM) classes in @P^n@ recursively.

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances #-}
module Math.RootLoci.CSM.Projective 
  ( -- * Pushforwards
    delta_star
  , pi_star
    -- * Easy things
  , tangentChernClass
  , smallestOrbitCSM
    -- * CSM calculation
  , upperCSM , lowerCSM
  , openCSM  , closedCSM
    -- * extracting coefficients
  , highestCoeff_ , lowestCoeff_
  , highestCoeff  , lowestCoeff 
  ) 
  where

--------------------------------------------------------------------------------

import Data.List
import Data.Maybe

import Math.Combinat.Numbers
import Math.Combinat.Sign
import Math.Combinat.Partitions.Integer
import Math.Combinat.Partitions.Set
import Math.Combinat.Sets

import qualified Data.Map as Map ; import Data.Map (Map)
import qualified Data.Set as Set ; import Data.Set (Set)

import Data.Array.IArray
import Data.Array (Array)

import Math.RootLoci.Algebra
import Math.RootLoci.Geometry
import Math.RootLoci.Misc

import qualified Math.RootLoci.Algebra.FreeMod as ZMod
 
--------------------------------------------------------------------------------

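{-

We have the maps

* Delta_nu : Q^d -> Q^n   (the diagonal map of a partition nu of n into d parts)
* pi       : Q^n -> P^n   (the map forgetting the order of the n points)

-}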

--------------------------------------------------------------------------------
-- * The order-forgetting map @pi : Q^n -> P^n@

-- | Pushforward of a single generator along @pi : Q^n -> P^n@: a subset
-- (= product) of @k@ H-s is sent to the generator @G k@ with coefficient
-- @(n-k)!@.
pi_star_1 :: Int -> HS -> (G,Integer)
pi_star_1 n (HS hs) =
  let k = length hs
  in  (G k, factorial (n - k))

-- | The pushforward map @pi_*@ along @pi : Q^n -> P^n@.
--
-- Each (cohomology) group generator upstairs is a subset (= product) of H-s;
-- it is sent to a multiple of a single group generator downstairs (see
-- @pi_star_1@). The map on the whole cohomology ring is the additive
-- extension of this.
--
pi_star 
  :: Int           -- ^ the number of points @n@ (with multiplicity)
  -> ZMod HS       -- ^ the cohomology class \"up\"
  -> ZMod G
pi_star n = ZMod.flatMap (uncurry ZMod.singleton . pi_star_1 n)

--------------------------------------------------------------------------------
-- * The diagonal maps @Delta_{\nu} : Q^d -> Q^n@
  
-- | Pushforward of a single generator (a subset = product of U-s) along the
-- diagonal map @Delta_nu : Q^d -> Q^n@ of the partition @nu@.
--
-- For the @k@-th part of the partition (with linear indices @is@):
--
-- * if @u_k@ occurs in the input, the factor is the single monomial of all
--   H-s indexed by @is@ (\"sigma_k\");
--
-- * otherwise the factor is the sum over @chooseN1 is@ (\"sigma_(k-1)\").
--
-- The factors are multiplied out with @listTensor@ and the resulting list of
-- monomials is turned into a linear combination by @ZMod.histogram@
-- (presumably counting repeated monomials; confirm against ZMod).
--
delta_star_1 :: Partition -> US -> ZMod HS
delta_star_1 part (US us) = ZMod.histogram almost where

  -- one list of linear indices per part of the partition
  idxtable = linearIndices part

  -- indices i such that u_i occurs in the input monomial
  uis = [ i | U i <- us ]

  -- outer lists  = one factor per part of the partition
  -- middle lists = linear combination (sum) of monomials
  -- inner lists  = a single monomial (product of H-s)
  stuff :: [[[H]]]
  stuff = (map . map . map) H (go 1 idxtable)

  -- multiply the factors together: one monomial per choice of summands
  almost :: [HS]
  almost = map (HS . concat) $ listTensor stuff

  go :: Int -> [[Int]] -> [[[Int]]]
  go _ []       = []
  go k (is:iss) = this : go (k+1) iss where
    this = if k `elem` uis
      then [is]                     -- "sigma_k"
      else chooseN1 is              -- "sigma_(k-1)"
  
-- | The pushforward @Delta_nu_*@ along the diagonal map of the partition @nu@.
--
-- Each group generator on the left is a subset (= product) of U-s; it is
-- sent to a linear combination of H-s by @delta_star_1@, and this is then
-- extended additively to the whole cohomology ring.
--
delta_star :: Partition -> ZMod US -> ZMod HS
delta_star part cls = ZMod.flatMap (delta_star_1 part) cls

--------------------------------------------------------------------------------
-- * Easy things

-- | The total Chern class of the tangent bundle of @Q^d = P^1 x P^1 x ... x P^1@.
--
-- This is the product @(1+2u_1)(1+2u_2)...(1+2u_d)@. Expanded, every
-- @k@-element subset of the indices (as produced by @choose_ k d@)
-- contributes the product of the corresponding U-s with coefficient @2^k@.
--
tangentChernClass :: Int -> ZMod US
tangentChernClass d = ZMod.fromList
  [ (US (map U xs), 2^k)
  | k  <- [0..d]
  , xs <- choose_ k d
  ]

-- | The CSM of the smallest orbit: 1 point with multiplicity @n@,
-- which is just the rational normal curve in @P^n@.
--
-- The curve has degree @n@ and codimension @n-1@, giving the term
-- @n * G (n-1)@. The coefficient of the point class @G n@ is the Euler
-- characteristic of the curve. The curve is isomorphic to @P^1@, so this
-- coefficient is @2@ for every @n@, not @2n@.
--
smallestOrbitCSM :: Int -> ZMod G
smallestOrbitCSM n = ZMod.fromList 
  [ (G (n-1) , fromIntegral n) 
  , (G  n    , 2             )    -- Euler characteristic of P^1
  ] 

--------------------------------------------------------------------------------
-- * CSM calculation

-- | The CSM class of the image of the diagonal map @Delta_nu : Q^d -> Q^n@,
-- as a class upstairs (in terms of H-s). We know that:
-- 
-- > csm(im(Delta)) = Delta_* c(TQ^d)
-- > c(TQ^d) = (1+2*u1) (1+2*u2) ... (1+2*ud)
--
-- From these, we can compute @csm(im(Delta_nu))@ directly. The computation
-- is wrapped in @pcache@.
--
upperCSM :: Partition -> ZMod HS
upperCSM = pcache calc where

  calc part@(Partition ps) =
    let d = length ps
    in  delta_star part (tangentChernClass d)

-- | A closed formula for @pi_*(csm(im(Delta_nu)))@. This should satisfy
--
-- > lowerCSM part = pi_star n (upperCSM part)
--
-- where @n@ is the sum of the parts. With @d@ the number of parts, the
-- coefficient of @G (n-d+r)@, for @r = 0..d@, is
-- @(d-r)! * 2^r * symPolyNum (d-r) parts@.
--
lowerCSM :: Partition -> ZMod G
lowerCSM = pcache calc where

  calc (Partition ps) = ZMod.fromList (map term [0..d]) where
    d      = length ps
    n      = sum ps
    parts  = map toInteger ps
    term r = ( G (n-d+r) , factorial (d-r) * 2^r * symPolyNum (d-r) parts )

-- | Consistency check: for every partition @p@ of @n@, pushing 'upperCSM'
-- forward along @pi@ gives the same class as the closed formula 'lowerCSM'.
check_lower_upper :: Int -> Bool
check_lower_upper n = all agrees (partitions n) where
  agrees p = pi_star n (upperCSM p) == lowerCSM p

-- | Cached CSM computation of the open strata.
--
-- We know that @pi_*(csm(im(Delta_nu)))@ ('lowerCSM') is a linear
-- combination of the open-stratum CSM classes. Subtracting the contributions
-- of the other strata listed by @preimageView@ and dividing by the remaining
-- coefficient yields the CSM class of this open stratum.
openCSM :: Partition -> ZMod G
openCSM = pcache calcOpenCSM where

  -- (pi_* upperCSM) = sum (chi * openCSM); solve for this stratum.
  -- NOTE(review): presumably @thisCoeff@ is the multiplicity of this stratum
  -- and @theClosure@ maps the smaller strata to their multiplicities; confirm
  -- against preimageView.
  calcOpenCSM :: Partition -> ZMod G
  calcOpenCSM part = ZMod.invScale thisCoeff (pushdown - smaller) where
    pushdown  = lowerCSM part   -- == pi_star (partitionWeight part) (upperCSM part)
    smaller   = ZMod.linComb [ (c , openCSM q) | (q,c) <- Map.assocs theClosure ]
    (thisCoeff,theClosure) = preimageView part

-- | CSM class of the closed stratum of a partition: the sum of the CSM
-- classes ('openCSM') of the open strata listed by @closureSet@.
closedCSM :: Partition -> ZMod G 
closedCSM = pcache calcClosedCSM where  

  calcClosedCSM :: Partition -> ZMod G
  calcClosedCSM = ZMod.sum . map openCSM . Set.toList . closureSet

--------------------------------------------------------------------------------

-- | Just the coefficient of the term returned by 'lowestCoeff'.
lowestCoeff_ :: ZMod G -> Integer
lowestCoeff_ z = snd (lowestCoeff z)

-- | Just the coefficient of the term returned by 'highestCoeff'.
highestCoeff_ :: ZMod G -> Integer
highestCoeff_ z = snd (highestCoeff z)

-- | The term (generator and coefficient) found by @ZMod.findMinTerm@.
--
-- Calls 'error' with a descriptive message when @ZMod.findMinTerm@ returns
-- 'Nothing' (presumably only for the zero class). Previously 'fromJust'
-- failed with an uninformative message in that case.
lowestCoeff :: ZMod G -> (G,Integer)
lowestCoeff = fromMaybe (error "lowestCoeff: the class has no terms") . ZMod.findMinTerm

-- | The term (generator and coefficient) found by @ZMod.findMaxTerm@.
--
-- Calls 'error' with a descriptive message when @ZMod.findMaxTerm@ returns
-- 'Nothing' (presumably only for the zero class). Previously 'fromJust'
-- failed with an uninformative message in that case.
highestCoeff :: ZMod G -> (G,Integer)
highestCoeff = fromMaybe (error "highestCoeff: the class has no terms") . ZMod.findMaxTerm

--------------------------------------------------------------------------------

{-
check_degree :: Partition -> Bool
check_degree p = hilbert p == lowestCoeff_ (closedCSM p)

check_euler_degree :: Partition -> Bool
check_euler_degree p@(Partition ps) = hilbert p == ((csmToEuler n $ closedCSM p) !! d) where
  d = length ps
  n = sum ps
-}

--------------------------------------------------------------------------------