-- | Symmetric polynomials in two variables @alpha@ and @beta@.
--
-- We provide three representations:
--
-- * symmetric polynomials in @alpha@ and @beta@ (Chern roots)
--
-- * polynomials in the elementary symmetric polynomials @c1=alpha+beta@ and @c2=alpha*beta@ (Chern classes)
--
-- * Schur polynomials @s[i,j]@
--
-- The monomials of the first two of these form monoids (the product of 
-- monomials is again a monomial), and can be used uniformly with the
-- help of some type-level hackery.
--
-- How to use the unified interface?
-- Suppose you have a function like this:
--
-- > tau :: ChernBase base => Int -> ZMod base
--
-- When calling it, you want to specify the output type (either @ZMod AB@ or @ZMod Chern@).
-- You can do that in three ways:
--
-- > x = tau @AB 10                  -- this needs -XTypeApplications
-- > x = (tau 10 :: ZMod AB)
-- > x = spec1' ChernRoot $ tau 10
--
-- The first one is the most convenient, but it only works with GHC 8 and later.
-- The other two work with older GHC versions, too.
--


{-# LANGUAGE DataKinds, TypeFamilies, Rank2Types, GADTs, StandaloneDeriving #-}
module Math.RootLoci.Algebra.SymmPoly where

--------------------------------------------------------------------------------

import Control.Monad
import Data.Proxy
import Data.Semigroup (Semigroup (..))
import Unsafe.Coerce as Unsafe

import qualified Data.Map.Strict as Map

import Math.Combinat.Numbers
import Math.Combinat.Sign
import System.Random

import Math.RootLoci.Algebra.FreeMod (ZMod)
import qualified Math.RootLoci.Algebra.FreeMod as ZMod
import Math.RootLoci.Misc.Pretty

--------------------------------------------------------------------------------
-- * Base monomials

-- | A monomial @alpha^i * beta^j@ in the Chern roots; these monomials
-- form the monomial basis of @Z[alpha,beta]@.
data AB
  = AB !Int !Int
  deriving (Eq, Ord, Show)

-- | A monomial @c1^i * c2^j@ in the Chern classes; these monomials form
-- the monomial basis of @Z[c1,c2]@.
data Chern
  = Chern !Int !Int
  deriving (Eq, Ord, Show)

-- | The Schur basis function @s[i,j]@. The conversion functions in this
-- module expect @i >= j >= 0@.
data Schur
  = Schur !Int !Int
  deriving (Eq, Ord, Show)

-- | The first Chern root, @alpha@, as a monomial.
alpha :: AB
alpha = AB 1 0

-- | The second Chern root, @beta@, as a monomial.
beta :: AB
beta = AB 0 1

--------------------------------------------------------------------------------

-- | The product of the two roots, @alpha * beta@ (which equals @c2@).
alphaBeta :: AB
alphaBeta = AB 1 1

-- | The first Chern class, @c1 = alpha + beta@, as a 'Chern' monomial.
c1 :: Chern
c1 = Chern 1 0

-- | The second Chern class, @c2 = alpha * beta@, as a 'Chern' monomial.
c2 :: Chern
c2 = Chern 0 1

--------------------------------------------------------------------------------
-- * Unified interface

-- | A singleton type recording, at the value level, which of the two
-- monomial bases ('AB' or 'Chern') is in use.
data Sing base where
  -- the Chern-root base
  ChernRoot  :: Sing AB
  -- the Chern-class base
  ChernClass :: Sing Chern

deriving instance Eq  (Sing b)
deriving instance Ord (Sing b)

-- | A common interface for working uniformly with Chern roots ('AB') and
-- Chern classes ('Chern').
--
-- The @chernTag*@ methods recover the 'Sing' for @base@ from any value
-- whose type mentions @base@ (optionally under one to three type
-- constructors). The remaining methods convert between @base@ and the
-- three concrete representations.
class ( Eq base, Ord base, Monoid base, Graded base, Pretty base )
      => ChernBase base where
  -- | Tag from a bare monomial.
  chernTag  :: base -> Sing base
  -- | Tag from a value under one type constructor.
  chernTag1 :: f base -> Sing base
  -- | Tag from a value under two type constructors.
  chernTag2 :: f (g base) -> Sing base
  -- | Tag from a value under three type constructors.
  chernTag3 :: f (g (h base)) -> Sing base

  -- | Convert from Chern roots.
  fromAB    :: ZMod AB -> ZMod base
  -- | Convert from Chern classes.
  fromChern :: ZMod Chern -> ZMod base
  -- | Convert from the Schur basis.
  fromSchur :: ZMod Schur -> ZMod base

  -- | Convert to Chern roots.
  toAB      :: ZMod base -> ZMod AB
  -- | Convert to Chern classes.
  toChern   :: ZMod base -> ZMod Chern
  -- | Convert to the Schur basis.
  toSchur   :: ZMod base -> ZMod Schur

instance ChernBase AB where
  -- conversions ('AB' is the identity case)
  toAB      = id
  fromAB    = id
  toChern   = abToChern
  fromChern = chernToAB
  toSchur   = abToSchur
  fromSchur = schurToAB
  -- tags: these must stay lazy in their argument, because 'select0' and
  -- friends pass them a self-referential value
  chernTag  = const ChernRoot
  chernTag1 = const ChernRoot
  chernTag2 = const ChernRoot
  chernTag3 = const ChernRoot

instance ChernBase Chern where
  -- conversions ('Chern' is the identity case)
  toChern   = id
  fromChern = id
  toAB      = chernToAB
  fromAB    = abToChern
  toSchur   = chernToSchur
  fromSchur = schurToChern
  -- tags: these must stay lazy in their argument, because 'select0' and
  -- friends pass them a self-referential value
  chernTag  = const ChernClass
  chernTag1 = const ChernClass
  chernTag2 = const ChernClass
  chernTag3 = const ChernClass

--------------------------------------------------------------------------------
-- * Helper functions for constructing and specializing uniform things

-- | Constructing uniform things: given the 'AB' variant and the 'Chern'
-- variant of a value, return the one whose type matches the @base@
-- demanded by the caller's context.
--
-- How it works: @final@ is defined in terms of itself. @chernTag final@
-- is used only to obtain a 'Sing' value for the result type, and
-- 'select0'' then pattern-matches on that value. This relies on the tag
-- methods never forcing their argument (the 'ChernBase' instances in this
-- module ignore it). An instance that inspected the argument would not
-- terminate.
select0 :: (AB, Chern) -> (ChernBase base => base)
select0 what = let final = select0' what (chernTag final) in final

-- | Like 'select0', for a value under one type constructor @f@; the tag
-- comes from 'chernTag1'.
select1 :: (f AB, f Chern) -> (ChernBase base => f base)
select1 what = let final = select1' what (chernTag1 final) in final

-- | Like 'select0', for a value under two type constructors; the tag
-- comes from 'chernTag2'.
select2 :: (f (g AB), f (g Chern)) -> (ChernBase base => f (g base))
select2 what = let final = select2' what (chernTag2 final) in final

-- | Like 'select0', for a value under three type constructors; the tag
-- comes from 'chernTag3'.
select3 :: (f (g (h AB)), f (g (h Chern))) -> (ChernBase base => f (g (h base)))
select3 what = let final = select3' what (chernTag3 final) in final

-- | Constructing uniform things using an explicit tag: pick the 'AB'
-- component or the 'Chern' component of the pair, according to the
-- singleton.
select0' :: (AB, Chern) -> (ChernBase base => Sing base -> base)
select0' (forRoots, forClasses) = \tag -> case tag of
  ChernRoot  -> forRoots
  ChernClass -> forClasses

-- | 'select0'' for a value under one type constructor.
select1' :: (f AB, f Chern) -> (ChernBase base => Sing base -> f base)
select1' (forRoots, forClasses) = \tag -> case tag of
  ChernRoot  -> forRoots
  ChernClass -> forClasses

-- | 'select0'' for a value under two type constructors.
select2' :: (f (g AB), f (g Chern)) -> (ChernBase base => Sing base -> f (g base))
select2' (forRoots, forClasses) = \tag -> case tag of
  ChernRoot  -> forRoots
  ChernClass -> forClasses

-- | 'select0'' for a value under three type constructors.
select3' :: (f (g (h AB)), f (g (h Chern))) -> (ChernBase base => Sing base -> f (g (h base)))
select3' (forRoots, forClasses) = \tag -> case tag of
  ChernRoot  -> forRoots
  ChernClass -> forClasses

-- | Specializing uniform things. The 'Sing' argument is never inspected;
-- it only fixes @base@, at which the polymorphic value is instantiated.
spec0' :: ChernBase base => Sing base -> (forall b. ChernBase b => b) -> base
spec0' _tag value = value

-- | 'spec0'' for a value under one type constructor.
spec1' :: ChernBase base => Sing base -> (forall b. ChernBase b => f b) -> f base
spec1' _tag value = value

-- | 'spec0'' for a value under two type constructors.
spec2' :: ChernBase base => Sing base -> (forall b. ChernBase b => f (g b)) -> f (g base)
spec2' _tag value = value

-- | 'spec0'' for a value under three type constructors.
spec3' :: ChernBase base => Sing base -> (forall b. ChernBase b => f (g (h b))) -> f (g (h base))
spec3' _tag value = value

{-
proxyOf :: a -> Proxy a
proxyOf _ = Proxy

proxyOf1 :: f a -> Proxy a
proxyOf1 _ = Proxy

proxyOf2 :: g (f a) -> Proxy a
proxyOf2 _ = Proxy
-}

--------------------------------------------------------------------------------

-- Since base-4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid', so
-- every 'Monoid' instance below needs a matching 'Semigroup' instance.
-- 'mappend' is still defined explicitly (as '(<>)') for older base
-- versions, where it has no default in terms of '(<>)'.

-- | Multiplication of monomials: the exponents add componentwise.
instance Semigroup AB where
  AB a1 b1 <> AB a2 b2 = AB (a1+a2) (b1+b2)

-- | The unit is the constant monomial @alpha^0 * beta^0@.
instance Monoid AB where
  mempty  = AB 0 0
  mappend = (<>)

-- | Multiplication of monomials: the exponents add componentwise.
instance Semigroup Chern where
  Chern e1 f1 <> Chern e2 f2 = Chern (e1+e2) (f1+f2)

-- | The unit is the constant monomial @c1^0 * c2^0@.
instance Monoid Chern where
  mempty  = Chern 0 0
  mappend = (<>)

-- | Partial instance: the product of two Schur basis elements is not a
-- single basis element, so combining any two values calls 'error'.
instance Semigroup Schur where
  _ <> _ = error "Schur/mappend: not a monoid"

-- | Only 'mempty' (@s[0,0]@) is meaningful; 'mappend' calls 'error'.
instance Monoid Schur where
  mempty  = Schur 0 0
  mappend = (<>)

--------------------------------------------------------------------------------

-- | Renders an 'AB' monomial with the variable names @a@ and @b@, using
-- 'showVarPower' for each factor and omitting factors with exponent zero.
-- The constant monomial renders as the empty string.
instance Pretty AB where
  pretty (AB 0 0) = ""
  pretty (AB i 0) = showVarPower "a" i
  pretty (AB 0 j) = showVarPower "b" j
  pretty (AB i j) = showVarPower "a" i ++ "*" ++ showVarPower "b" j
 
-- | Renders a 'Chern' monomial with the variable names @c1@ and @c2@,
-- using 'showVarPower' for each factor and omitting factors with exponent
-- zero. The constant monomial renders as the empty string.
instance Pretty Chern where
  pretty monom = case monom of
    Chern 0 0 -> ""
    Chern i 0 -> showVarPower "c1" i
    Chern 0 j -> showVarPower "c2" j
    Chern i j -> showVarPower "c1" i ++ "*" ++ showVarPower "c2" j

-- | Renders @Schur a b@ as @s[a,b]@, or as @s[a]@ when @b@ is zero.
instance Pretty Schur where
  pretty (Schur a 0) = "s[" ++ show a ++ "]"
  pretty (Schur a b) = "s[" ++ show a ++ "," ++ show b ++ "]"

--------------------------------------------------------------------------------
-- * Grading

-- | Monomials that carry a degree.
class Graded a where
  -- | The degree of a monomial.
  grade :: a -> Int

-- | @alpha@ and @beta@ each have degree 1.
instance Graded AB where
  grade (AB i j) = i + j

-- | @c1@ has degree 1 and @c2@ has degree 2.
instance Graded Chern where
  grade (Chern e f) = e + 2 * f

-- | @s[i,j]@ has degree @i + j@.
instance Graded Schur where
  grade (Schur i j) = i + j

-- | Keeps only the terms whose monomial has exactly the given degree
-- (as measured by 'grade').
filterGrade :: (Ord b, Graded b) => Int -> ZMod b -> ZMod b
filterGrade deg = ZMod.onFreeMod (Map.filterWithKey keep)
  where
    keep mono _coeff = grade mono == deg

--------------------------------------------------------------------------------
-- * Conversions

-- | Expands a polynomial in the Chern classes into the Chern roots, using
-- @c1^k * c2^n = (alpha+beta)^k * (alpha*beta)^n@ and the binomial theorem.
chernToAB :: ZMod Chern -> ZMod AB
chernToAB = ZMod.flatMap expand
  where
    -- one Chern monomial as a combination of AB monomials
    expand :: Chern -> ZMod AB
    expand (Chern k n) = ZMod.fromList (map term [0 .. k])
      where
        term i = (AB (n + i) (n + k - i), binomial k i)

--------------------------------------------------------------------------------

-- | Converts a symmetric polynomial written in the Chern roots ('AB') into
-- the Chern classes ('Chern'), via 'symmetricReduction'. Calls 'error'
-- if the input turns out not to be symmetric.
abToChern :: ZMod AB -> ZMod Chern
abToChern = either notSymmetric id . symmetricReduction
  where
    notSymmetric _ = error "abToChern: input was not symmetric"

-- | Rewrites a polynomial in the Chern roots in terms of the Chern classes.
-- It repeatedly takes the maximal term (as returned by
-- 'ZMod.findMaxTerm'), @k * alpha^n * beta^m@, and subtracts
-- @k * c1^(n-m) * c2^m@ expanded into the roots.
--
-- Returns @Right q@ when the input was symmetric; @q@ is then the input
-- expressed in @c1@ and @c2@. Returns @Left (q, rest)@ as soon as a
-- maximal term with @n < m@ appears. In that case @q@ is the part reduced
-- so far and @rest@ is the remainder at that point.
symmetricReduction :: ZMod AB -> Either (ZMod Chern, ZMod AB) (ZMod Chern)
symmetricReduction = loop []
  where
    loop acc rest =
      let reduced = ZMod.fromList acc
      in  case ZMod.findMaxTerm rest of
            Nothing -> Right reduced
            Just (AB n m, k)
              | n < m     -> Left (reduced, rest)
              | otherwise ->
                  let lead = Chern (n - m) m
                  in  loop ((lead, k) : acc) (rest - ZMod.scale k (expandChern lead))

    -- c1^k * c2^n = (alpha+beta)^k * (alpha*beta)^n, expanded binomially
    expandChern :: Chern -> ZMod AB
    expandChern (Chern k n) =
      ZMod.fromList [ (AB (n + i) (n + k - i), binomial k i) | i <- [0 .. k] ]
            
--------------------------------------------------------------------------------

-- | Converts a combination of Schur polynomials into the Chern roots.
--
-- @s[a,b]@ (with @a >= b >= 0@) expands to the sum of
-- @alpha^i * beta^(a+b-i)@ for @i = a, a-1, .., b@. Any other index pair
-- is rejected with 'error'.
schurToAB :: ZMod Schur -> ZMod AB
schurToAB = ZMod.flatMap expand
  where
    expand :: Schur -> ZMod AB
    expand (Schur a b)
      | b < 0 || b > a = error "schurExpandAB"
      | otherwise      = ZMod.fromList [ (AB i (a + b - i), 1) | i <- [a, a - 1 .. b] ]

    -- Reference (Mathematica), as a ratio of alternants:
    --   schurab[i_, j_] :=
    --     Expand[Factor[ Det[{{a^(i + 1), b^(i + 1)}, {a^j, b^j}}]] /
    --       Det[{{a, b}, {1, 1}}] ]

--------------------------------------------------------------------------------

-- | Converts a combination of Schur polynomials into the Chern classes
-- (the elementary symmetric polynomials).
--
-- For @a >= b >= 0@, with @d = a - b@, @s[a,b]@ becomes the sum over
-- @j = 0 .. d/2@ of @sign(j) * binomial (d-j) j * c1^(d-2j) * c2^(b+j)@.
-- Here @sign(j)@ is 'paritySignValue' @j@. Any other index pair is
-- rejected with 'error'.
schurToChern :: ZMod Schur -> ZMod Chern
schurToChern = ZMod.flatMap expand
  where
    expand :: Schur -> ZMod Chern
    expand (Schur a b)
      | b < 0 || b > a = error "schurExpandChern_1"
      | otherwise      = ZMod.fromList (map term [0 .. d `div` 2])
      where
        d = a - b
        term j = (Chern (d - 2 * j) (b + j), paritySignValue j * binomial (d - j) j)

    -- Reference (Mathematica):
    --   schurcd[i_, j_] := SymmetricReduction[schurab[i, j], {a, b}, {c1, c2}][[1]]

--------------------------------------------------------------------------------

-- | Converts a polynomial in the Chern classes into the Schur basis.
--
-- For @e, f >= 0@, @c1^e * c2^f@ becomes the sum over @i = 0 .. e/2@ of
-- 'catalanTriangle' @(e-i) i@ times @s[e+f-i, f+i]@. That coefficient is
-- presumably the number of standard Young tableaux of shape @(e-i, i)@.
-- Negative exponents are rejected with 'error'.
chernToSchur :: ZMod Chern -> ZMod Schur
chernToSchur = ZMod.flatMap expand
  where
    expand :: Chern -> ZMod Schur
    expand (Chern e f)
      | e >= 0 && f >= 0 =
          ZMod.fromList
            [ (Schur (e + f - i) (f + i), catalanTriangle (e - i) i)
            | i <- [0 .. e `div` 2] ]
      | otherwise = error "chernExpandSchur"

--------------------------------------------------------------------------------

-- | Converts a symmetric polynomial in the Chern roots into the Schur
-- basis, going through the Chern classes. Fails (via 'abToChern') on
-- non-symmetric input.
abToSchur :: ZMod AB -> ZMod Schur
abToSchur ab = chernToSchur (abToChern ab)

-- | Converts Chern classes to the Schur basis by repeated leading-term
-- elimination. The maximal term @k * c1^a * c2^b@ (per
-- 'ZMod.findMaxTerm') is matched with @k * s[a+b,b]@, and the Chern
-- expansion of that Schur term is subtracted. Presumably kept as a simple
-- reference to compare against 'chernToSchur'.
chernToSchurNaive :: ZMod Chern -> ZMod Schur
chernToSchurNaive input = ZMod.fromList (peel input)
  where
    peel zmod = case ZMod.findMaxTerm zmod of
      Nothing -> []
      Just (Chern a b, k) ->
        let s = Schur (a + b) b
        in  (s, k) : peel (zmod - ZMod.scale k (expandSchur s))

    -- s[a,b] expressed in c1, c2 (same formula as in 'schurToChern')
    expandSchur :: Schur -> ZMod Chern
    expandSchur (Schur a b)
      | b > a || b < 0 = error "schurExpandChern_1"
      | otherwise      = ZMod.fromList
          [ (Chern (a - b - 2 * j) (b + j), paritySignValue j * binomial (a - b - j) j)
          | j <- [0 .. (a - b) `div` 2] ]

--------------------------------------------------------------------------------
-- * random polynomials for testing

-- | A random monomial @c1^a * c2^b@ with @0 <= a <= 30@ and
-- @0 <= b <= 15@, for testing. @a@ is drawn before @b@.
randomChernMonom :: IO Chern
randomChernMonom = Chern <$> randomRIO (0, 30) <*> randomRIO (0, 15)

-- | A random Schur monomial @s[a+b,b]@ with @a@ and @b@ drawn (in that
-- order) from @[0,30]@. The indices are therefore always valid
-- (@a+b >= b >= 0@).
randomSchurMonom :: IO Schur
randomSchurMonom = mk <$> randomRIO (0, 30) <*> randomRIO (0, 30)
  where
    mk a b = Schur (a + b) b

-- | Pairs the result of an action with a random coefficient in
-- @[-100,100]@. The coefficient is drawn before the action runs.
withRandomCoeff :: IO a -> IO (a,Integer)
withRandomCoeff action = flip (,) <$> randomRIO (-100, 100) <*> action

-- | A random Chern polynomial built from between 0 and 100 random
-- (monomial, coefficient) pairs, passed to 'ZMod.fromList'.
randomChernPoly :: IO (ZMod Chern)
randomChernPoly =
  randomRIO (0, 100) >>= \count ->
    ZMod.fromList <$> replicateM count (withRandomCoeff randomChernMonom)

-- | A random Schur polynomial built from between 0 and 100 random
-- (monomial, coefficient) pairs, passed to 'ZMod.fromList'.
randomSchurPoly :: IO (ZMod Schur)
randomSchurPoly =
  randomRIO (0, 100) >>= \count ->
    ZMod.fromList <$> replicateM count (withRandomCoeff randomSchurMonom)

--------------------------------------------------------------------------------