-- | Formula for the dual cohomology class of the /cones/ over the strata (sometimes called the Thom polynomial)
-- in terms of the Chern classes @c1@ and @c2@, from the author's MSc thesis.
--
-- Note that the dual class agrees with the lowest degree part of the CSM class.
--
-- See: Balazs Komuves: Thom Polynomials via Restriction Equations; MSc thesis, ELTE, 2003
--

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances, ScopedTypeVariables #-}
module Math.RootLoci.Dual.Restriction where

--------------------------------------------------------------------------------

import Data.List
import Data.Ratio

import Control.Monad

import Math.Combinat.Numbers
import Math.Combinat.Sign
import Math.Combinat.Partitions.Integer
import Math.Combinat.Partitions.Set
import Math.Combinat.Sets
import Math.Combinat.Tuples

import qualified Data.Set as Set ; import Data.Set (Set)

import qualified Math.RootLoci.Algebra.FreeMod as ZMod

import Math.RootLoci.Algebra
import Math.RootLoci.Classic
import Math.RootLoci.Geometry
import Math.RootLoci.Misc

--------------------------------------------------------------------------------
-- * The dual class

-- | The affine Thom polynomial formula from my MSc thesis.
--
-- Computes the dual class of the cone over the stratum indexed by the given
-- partition, as an integral combination of the basis elements @Chern e f@.
-- Writing @n@ for the weight and @d@ for the length of the partition, the
-- basis elements involved are @Chern (n-d-2*j) j@ for @j = 0 .. (n-d) `div` 2@.
--
-- The coefficients come from one of three closed formulas:
--
--   * a single part @(n)@: @single@;
--
--   * two equal parts @(p,p)@: @double@;
--
--   * everything else: @lambda@, a sum over all splittings of the partition
--     into two nonempty sub-partitions.
--
-- Calls 'error' on the empty partition, and also if a coefficient (computed
-- over the rationals) turns out not to be an integer.
affineDualMSc :: Partition -> ZMod Chern
affineDualMSc part@(Partition ps)
  | null ps   = error "affineDualMSc: empty partition"
  | otherwise = ZMod.fromList
      [ ( Chern (n-d-2*j) j , rat2int (coeff j) ) | j <- [ 0 .. div (n-d) 2 ] ]

  where

    n = sum ps      -- weight of the partition
    d = length ps   -- number of parts

    p = div  n    2
    q = div (n-1) 2

    -- Select the closed formula matching the shape of the partition.
    -- Only the shape is matched, so nothing shadows the outer @n@.
    coeff :: Int -> Rational
    coeff = case ps of
      [_]            -> single
      [a,b] | a == b -> double
      _              -> lambda

    -- The formulas are evaluated over the rationals, but the results must be
    -- integers; anything else indicates a bug.
    rat2int :: Rational -> Integer
    rat2int r = case denominator r of
      1 -> numerator r
      _ -> error "affineDualMSc: coefficient is not an integer"

    -- General case: a signed sum over all splittings @(phi,psi)@ of the
    -- partition into two nonempty sub-partitions, where @nphi@ is the weight
    -- of @phi@ and @lpsi@ is the number of parts of @psi@.
    lambda :: Int -> Rational
    lambda j = (fi n / 2)^(n-2*q) * fi (doubleFactorial (n-2))^2 * s where
      s = sum
        [ negateIfOdd (n + p + j + lpsi) $ bigTheta j nphi * (fi (2*nphi-n) / fi n)^(d-2) / (fi $ aut phi * aut psi)
        | (phi,psi) <- Set.toList (divideIntoTwoNonEmpty part)
        , let nphi = sum $ fromPartition phi
        , let lpsi = length $ fromPartition psi
        ]

    -- gamma k = k(k-n) / (2k-n)^2; defined as 0 when 2k == n, which would
    -- otherwise be a division by zero.
    gamma :: Int -> Rational
    gamma k
      | 2*k == n   = 0
      | otherwise  = fi (k*(k-n)) / fi ((2*k-n)*(2*k-n))

    bigTheta :: Int -> Int -> Rational
    bigTheta j k
      | 2*k == n   = 0
      | otherwise  = gamma k * smallTheta j k

    -- elementary symmetric polynomial of degree q-1-j in the gammas,
    -- leaving out the indices k and n-k
    smallTheta :: Int -> Int -> Rational
    smallTheta j k = sympoly (q-1-j) [ gamma i | i<-[1..q] , i/=k, i/=n-k ]

    fi :: Integral a => a -> Rational
    fi = fromIntegral

    -- elementary symmetric polynomial of degree q-j in gamma 1 .. gamma q
    sqj :: Int -> Rational
    sqj j = sympoly (q-j) [ gamma i | i<-[1..q] ]

    -- sympoly k xs = sum of the products of all k-element sub-lists of xs
    sympoly :: Int -> [Rational] -> Rational
    sympoly k xs = sum [ product ys | ys <- choose k xs ]

    -- S(n)
    single :: Int -> Rational
    single j = fi (factorial n) / (product [ gamma i | i<-[1..q] ])
             * negateIfOdd j (sqj j)

    -- S(p,p)
    double :: Int -> Rational
    double j = fi (doubleFactorial n)^2 / 4
             * negateIfOdd (q+j) (sqj j)


--------------------------------------------------------------------------------
-- * Degree

-- | Compute the projective degree from the affine equivariant dual class.
-- The result can be checked against Hilbert's formula.
--
-- This is just a simple substitution:
--
-- > alpha  ->  1/n
-- > beta   ->  1/n
--
-- or, in terms of Chern classes:
--
-- > c1     ->  2/n
-- > c2     ->  1/n^2
--
-- The monomial @Chern e f@ is evaluated as @c1^e * c2^f@, multiplied by its
-- coefficient, and the resulting rational sum is converted with @fromRat@.
projDegreeFromDual
  :: Int             -- ^ number of points = dimension of the projective space @P^n@
  -> ZMod Chern      -- ^ dual class
  -> Integer         -- ^ degree
projDegreeFromDual n zm = fromRat total where

  total :: Rational
  total = sum (map evalTerm (ZMod.toList zm))

  -- value of a single term @coeff * c1^e * c2^f@ after the substitution
  evalTerm (Chern e f, coeff) = fromIntegral coeff * c1Value ^ e * c2Value ^ f

  c1Value, c2Value :: Rational
  c1Value = 2 / fromIntegral n
  c2Value = 1 / fromIntegral (n*n)

-- | Compute the projective degree of the stratum indexed by a partition.
-- The dual class comes from 'affineDualMSc' and is then evaluated with
-- 'projDegreeFromDual', using the weight of the partition as @n@.
degreeMSc :: Partition -> Integer
degreeMSc part = projDegreeFromDual weight dual where
  weight = partitionWeight part
  dual   = affineDualMSc part

{-

check_msc_degree :: Bool
check_msc_degree = and
  [ msc_degree part == hilbert part | n<-[1..12] , part <- partitions n ]
-}

--------------------------------------------------------------------------------
-- * extract the dual class from the CSM class 

-- | The dual class of the closure agrees with the lowest degree part of the
-- CSM class.
--
-- Projective version: first keep only the basis elements of the form
-- @Gam 0 ab@, relabelled as @ab@ via 'ZMod.filterBase'. Then extract the
-- lowest degree part with 'dualClassFromAffCSM'.
dualClassFromProjCSM :: forall base. ChernBase base => ZMod (Gam base) -> ZMod base
dualClassFromProjCSM csm = dualClassFromAffCSM (ZMod.filterBase dropGamma csm) where

  -- keep @Gam 0 ab@ as @ab@; discard every @Gam k ab@ with nonzero @k@
  dropGamma :: Gam base -> Maybe base
  dropGamma (Gam k ab)
    | k == 0    = Just ab
    | otherwise = Nothing

-- | The dual class of the closure agrees with the lowest degree part of the
-- (affine) CSM class. This keeps only the terms whose 'grade' is minimal.
--
-- The zero class (no terms at all) is returned unchanged. Previously it
-- crashed, because it took the 'minimum' of an empty list.
dualClassFromAffCSM :: ChernBase base => ZMod base -> ZMod base
dualClassFromAffCSM csm =
  case map (grade . fst) (ZMod.toList csm) of
    []      -> csm
    degrees -> filterGrade (minimum degrees) csm

--------------------------------------------------------------------------------
-- * Lemma 9.1.3

{-
test_lemma_913 = and
  [ lemma913 p h 
  | n<-[1..10] 
  , p@(Partition ps)<-partitions n
  , let d=length ps
  , h<-[0..d]
  ]

test_lemma_913' =  
  [ (lemma913' p h,(p,h),(d,n))
  | n<-[1..10] 
  , p@(Partition ps)<-partitions n
  , let d=length ps
  , h<-[0..d]
  ]
-}

-- | Checks whether Lemma 9.1.3 from the thesis holds for the given partition
-- and exponent @h@. Both sides are evaluated exactly, over the rationals.
--
-- Let @n@ be the weight and @d@ the number of parts of the partition.
-- The left-hand side sums, over all splittings @(phi,psi)@ returned by
-- 'divideIntoTwo' (with parts @qs@ and @rs@), the terms
--
-- > (-1)^(length rs) * ((2 * sum qs - n) / 2)^h * aut part / (aut phi * aut psi)
--
-- The right-hand side depends on @h@:
--
--   * @h == d@: it is @d! * product ps@;
--
--   * @h < d@: it is @0@;
--
--   * @h > d@: it is the sentinel @-666@. Presumably the lemma is not stated
--     for this range, so the check should not be trusted there.
lemma913 :: Partition -> Int -> Bool
lemma913 part@(Partition ps) h = lhs == rhs where

  n = sum ps
  d = length ps

  lhs :: Rational
  lhs = sum $ map term $ Set.toList (divideIntoTwo part)

  term (phi@(Partition qs), psi@(Partition rs)) =
    negateIfOdd (length rs) $
      (toQ (2 * sum qs - n) / 2) ^ h * toQ (aut part) / toQ (aut phi * aut psi)

  rhs :: Rational
  rhs = case compare h d of
    EQ -> toQ (factorial d) * product (map toQ ps)
    LT -> 0
    GT -> -666

  toQ :: Integral a => a -> Rational
  toQ = fromIntegral


--------------------------------------------------------------------------------
-- * helper functions

-- | Different ways to divide a partition into an ordered pair of
-- sub-partitions: each part goes either to the first or to the second
-- component, as selected by a boolean mask.
--
-- Because the result is a 'Set', splittings that coincide (due to repeated
-- parts) appear only once. The trivial splittings, with one side empty, are
-- included; see 'divideIntoTwoNonEmpty' to exclude them.
divideIntoTwo :: Partition -> Set (Partition,Partition)
divideIntoTwo (Partition ps) = Set.fromList
  [ (pick id mask, pick not mask) | mask <- binaryTuples (length ps) ]
  where
    -- the parts whose corresponding mask entry satisfies @keep@
    pick keep mask = Partition [ k | (b,k) <- zip mask ps, keep b ]

-- | Different ways to divide a partition into two /nonempty/ partitions.
-- This is 'divideIntoTwo' with the two trivial splittings removed, namely
-- @(emptyPartition, p)@ and @(p, emptyPartition)@.
divideIntoTwoNonEmpty :: Partition -> Set (Partition,Partition)
divideIntoTwoNonEmpty p = divideIntoTwo p `Set.difference` trivial where
  trivial = Set.fromList [ (emptyPartition, p), (p, emptyPartition) ]

--------------------------------------------------------------------------------