-- | Symmetric polynomials in two variables @alpha@ and @beta@.
--
-- We provide three representations:
--
-- * symmetric polynomials in @alpha@ and @beta@ (Chern roots);
--
-- * polynomials in the elementary symmetric polynomials @c1=alpha+beta@
--   and @c2=alpha*beta@ (Chern classes);
--
-- * Schur polynomials @s[i,j]@.
--
-- The monomials of the first two of these form monoids (the product of
-- monomials is again a monomial), and can be used uniformly with the
-- help of some type-level hackery.
--
-- How to use the unified interface? Suppose you have a function like this:
--
-- > tau :: ChernBase base => Int -> ZMod base
--
-- When calling it, you need to specify the output type (either @ZMod AB@
-- or @ZMod Chern@). You can do that in three ways:
--
-- > x = tau @AB 10                  -- this needs -XTypeApplications
-- > x = (tau 10 :: ZMod AB)
-- > x = spec1' ChernRoot $ tau 10
--
-- The first one is the most convenient, but it requires GHC 8.0 or later
-- (the first version supporting @TypeApplications@). The other two also
-- work with older GHC versions.

{-# LANGUAGE DataKinds, TypeFamilies, Rank2Types, GADTs, StandaloneDeriving #-}
module Math.RootLoci.Algebra.SymmPoly where


import Control.Monad
import Data.Proxy
import Data.Semigroup (Semigroup (..))
import Unsafe.Coerce as Unsafe

import qualified Data.Map.Strict as Map
import System.Random

import Math.Combinat.Sign
import Math.Combinat.Numbers

import Math.RootLoci.Algebra.FreeMod (ZMod)
import qualified Math.RootLoci.Algebra.FreeMod as ZMod
import Math.RootLoci.Misc.Pretty

-- * Base monomials

-- | Chern roots: @alpha^i * beta^j@, monomial base of @Z[alpha,beta]@.
--
-- @AB i j@ stores the exponent @i@ of @alpha@ followed by the exponent @j@
-- of @beta@. The derived 'Ord' compares the @alpha@ exponent first.
data AB = AB !Int !Int deriving (Eq,Ord,Show)

-- | Chern classes: @c1^i * c2^j@, monomial base of @Z[c1,c2]@.
--
-- @Chern i j@ stores the exponent @i@ of @c1@ followed by the exponent @j@
-- of @c2@. The derived 'Ord' compares the @c1@ exponent first.
data Chern = Chern !Int !Int deriving (Eq,Ord,Show)

-- | Schur basis function: @s[i,j]@.
--
-- The conversion functions in this module expect @i >= j >= 0@ and call
-- 'error' on other indices.
data Schur = Schur !Int !Int deriving (Eq,Ord,Show) 

-- | The Chern root @alpha@, i.e. the monomial @alpha^1 * beta^0@.
alpha :: AB
alpha = AB 1 0

-- | The Chern root @beta@, i.e. the monomial @alpha^0 * beta^1@.
beta :: AB
beta = AB 0 1

-- | The product @alpha * beta@ (which equals @c2@), i.e. the monomial
-- @alpha^1 * beta^1@.
alphaBeta :: AB
alphaBeta = AB 1 1

-- | The two Chern classes as monomials in the 'Chern' base:
-- @c1 = alpha + beta@ is @c1^1 * c2^0@, and
-- @c2 = alpha * beta@ is @c1^0 * c2^1@.
c1, c2 :: Chern
c1 = Chern 1 0
c2 = Chern 0 1

-- * Unified interface

-- | A singleton for distinguishing the two cases. Pattern matching on a
-- @'Sing' base@ value tells the type checker whether @base@ is 'AB'
-- (constructor 'ChernRoot') or 'Chern' (constructor 'ChernClass').
data Sing base where
  ChernRoot  :: Sing AB
  ChernClass :: Sing Chern

deriving instance Eq  (Sing base)
deriving instance Ord (Sing base)

-- | Common interface to work with Chern classes and Chern roots uniformly.
--
-- The @chernTag*@ methods recover the 'Sing' singleton for @base@ from any
-- value whose type mentions @base@, either directly or under up to three
-- type constructors. 'select0' .. 'select3' apply them to the very value
-- they are defining, so implementations must not evaluate the argument.
-- Both instances in this module ignore it.
--
-- The @from*@ and @to*@ methods convert linear combinations between @base@
-- and the three representations ('AB', 'Chern', 'Schur').
class (Eq base, Ord base, Monoid base, Graded base, Pretty base) => ChernBase base where
  -- singleton recovery; the argument is only used for its type
  chernTag  :: base       -> Sing base
  chernTag1 :: f base     -> Sing base
  chernTag2 :: f (g base)     -> Sing base
  chernTag3 :: f (g (h base)) -> Sing base
  -- conversions into @base@
  fromAB    :: ZMod AB    -> ZMod base  
  fromChern :: ZMod Chern -> ZMod base  
  fromSchur :: ZMod Schur -> ZMod base
  -- conversions out of @base@
  toAB      :: ZMod base  -> ZMod AB  
  toChern   :: ZMod base  -> ZMod Chern
  toSchur   :: ZMod base  -> ZMod Schur

-- | Chern roots. Conversions to and from 'AB' are the identity, and the
-- tag functions ignore their argument.
instance ChernBase AB where
  chernTag  = const ChernRoot
  chernTag1 = const ChernRoot
  chernTag2 = const ChernRoot
  chernTag3 = const ChernRoot
  toAB      = id
  toChern   = abToChern
  toSchur   = abToSchur
  fromAB    = id
  fromChern = chernToAB
  fromSchur = schurToAB

-- | Chern classes. Conversions to and from 'Chern' are the identity, and
-- the tag functions ignore their argument.
instance ChernBase Chern where
  chernTag  = const ChernClass
  chernTag1 = const ChernClass
  chernTag2 = const ChernClass
  chernTag3 = const ChernClass
  toAB      = chernToAB
  toChern   = id
  toSchur   = chernToSchur
  fromAB    = abToChern
  fromChern = id
  fromSchur = schurToChern

-- * Helper functions for constructing and specializing uniform things

-- | Constructing uniform things: given the 'AB' and the 'Chern' version of
-- a value, produce whichever one the caller's choice of @base@ asks for.
--
-- The @let@ is self-referential. @final@ is passed to 'chernTag' only so
-- that its type selects the singleton. This relies on 'chernTag' not
-- evaluating its argument.
select0 :: (AB, Chern) -> (ChernBase base => base)
select0 what = let final = select0' what (chernTag final) in final

-- | Like 'select0', for values of type @f base@.
select1 :: (f AB, f Chern) -> (ChernBase base => f base)
select1 what = let final = select1' what (chernTag1 final) in final

-- | Like 'select0', for values of type @f (g base)@.
select2 :: (f (g AB), f (g Chern)) -> (ChernBase base => f (g base))
select2 what = let final = select2' what (chernTag2 final) in final

-- | Like 'select0', for values of type @f (g (h base))@.
select3 :: (f (g (h AB)), f (g (h Chern))) -> (ChernBase base => f (g (h base)))
select3 what = let final = select3' what (chernTag3 final) in final

-- | Constructing uniform things using a tag. Matching on the 'Sing'
-- singleton refines @base@: 'ChernRoot' returns the 'AB' component of the
-- pair, and 'ChernClass' returns the 'Chern' component.
select0' :: (AB, Chern) -> (ChernBase base => Sing base -> base)
select0' (ab,ch) = \sing -> case sing of { ChernRoot -> ab ; ChernClass -> ch }

-- | Like 'select0'', for values of type @f base@.
select1' :: (f AB, f Chern) -> (ChernBase base => Sing base -> f base)
select1' (ab,ch) = \sing -> case sing of { ChernRoot -> ab ; ChernClass -> ch }

-- | Like 'select0'', for values of type @f (g base)@.
select2' :: (f (g AB), f (g Chern)) -> (ChernBase base => Sing base -> f (g base))
select2' (ab,ch) = \sing -> case sing of { ChernRoot -> ab ; ChernClass -> ch }

-- | Like 'select0'', for values of type @f (g (h base))@.
select3' :: (f (g (h AB)), f (g (h Chern))) -> (ChernBase base => Sing base -> f (g (h base)))
select3' (ab,ch) = \sing -> case sing of { ChernRoot -> ab ; ChernClass -> ch }

-- | Specializing uniform things: instantiates a value that is polymorphic
-- in the base at the base named by the singleton. The singleton is never
-- inspected; it only fixes the type (handy without @TypeApplications@).
spec0' :: ChernBase base => Sing base -> (forall b. ChernBase b => b) -> base
spec0' _ x = x

-- | Like 'spec0'', for values of type @f b@.
spec1' :: ChernBase base => Sing base -> (forall b. ChernBase b => f b) -> f base
spec1' _ x = x

-- | Like 'spec0'', for values of type @f (g b)@.
spec2' :: ChernBase base => Sing base -> (forall b. ChernBase b => f (g b)) -> f (g base)
spec2' _ x = x

-- | Like 'spec0'', for values of type @f (g (h b))@.
spec3' :: ChernBase base => Sing base -> (forall b. ChernBase b => f (g (h b))) -> f (g (h base))
spec3' _ x = x

-- | A 'Proxy' for the type of the given value. The value is never evaluated.
proxyOf :: a -> Proxy a
proxyOf = const Proxy

-- | A 'Proxy' for the element type under one type constructor.
proxyOf1 :: f a -> Proxy a
proxyOf1 = const Proxy

-- | A 'Proxy' for the element type under two nested type constructors.
proxyOf2 :: g (f a) -> Proxy a
proxyOf2 = const Proxy


-- | Multiplication of Chern-root monomials: the exponents add componentwise.
-- A 'Semigroup' instance is required for 'Monoid' since GHC 8.4 (base-4.11).
instance Semigroup AB where
  (AB a1 b1) <> (AB a2 b2) = AB (a1+a2) (b1+b2)

instance Monoid AB where
  mempty  = AB 0 0
  -- explicit for GHC < 8.4, where 'mappend' has no default
  mappend = (<>)

-- | Multiplication of Chern-class monomials: the exponents add componentwise.
-- A 'Semigroup' instance is required for 'Monoid' since GHC 8.4 (base-4.11).
instance Semigroup Chern where
  (Chern e1 f1) <> (Chern e2 f2) = Chern (e1+e2) (f1+f2)

instance Monoid Chern where
  mempty  = Chern 0 0
  -- explicit for GHC < 8.4, where 'mappend' has no default
  mappend = (<>)

-- | Only 'mempty' is meaningful for 'Schur'; combining two values calls
-- 'error'. A 'Semigroup' instance is required for 'Monoid' since GHC 8.4
-- (base-4.11).
instance Semigroup Schur where
  (<>) = error "Schur/mappend: not a monoid"

instance Monoid Schur where
  mempty  = Schur 0 0
  -- explicit for GHC < 8.4, where 'mappend' has no default
  mappend = (<>)


-- | Renders Chern-root monomials with variables @a@ and @b@ via
-- 'showVarPower'. The unit monomial renders as the empty string.
instance Pretty AB where
  pretty (AB 0 0) = ""
  pretty (AB i 0) = showVarPower "a" i
  pretty (AB 0 j) = showVarPower "b" j
  pretty (AB i j) = showVarPower "a" i ++ "*" ++ showVarPower "b" j

-- | Renders Chern-class monomials with variables @c1@ and @c2@ via
-- 'showVarPower'. The unit monomial renders as the empty string.
instance Pretty Chern where
  pretty mono = case mono of
    Chern 0 0 -> ""
    Chern i 0 -> showVarPower "c1" i
    Chern 0 j -> showVarPower "c2" j
    Chern i j -> showVarPower "c1" i ++ "*" ++ showVarPower "c2" j

-- | Renders @Schur a b@ as @s[a,b]@, or as @s[a]@ when @b@ is zero.
instance Pretty Schur where
  pretty (Schur a b) = "s[" ++ indices ++ "]"
    where
      indices
        | b == 0    = show a
        | otherwise = show a ++ "," ++ show b

-- * Grading

-- | Monomials that carry a degree.
class Graded a where
  grade :: a -> Int

-- | @alpha@ and @beta@ both have degree 1.
instance Graded AB where
  grade (AB i j) = i + j

-- | @c1@ has degree 1 and @c2@ has degree 2.
instance Graded Chern where
  grade (Chern i j) = i + 2 * j

-- | @s[i,j]@ has degree @i+j@.
instance Graded Schur where
  grade (Schur i j) = i + j

-- | Keeps only the terms whose base monomial has degree exactly @g@
-- (see 'grade'). All other terms are dropped.
filterGrade :: (Ord b, Graded b) => Int -> ZMod b -> ZMod b
filterGrade g = ZMod.onFreeMod (Map.filterWithKey hasDegree)
  where
    hasDegree mono _coeff = grade mono == g

-- * Conversions

-- | Convert Chern classes to Chern roots by substituting
-- @c1 = alpha + beta@ and @c2 = alpha * beta@ and expanding.
chernToAB :: ZMod Chern -> ZMod AB
chernToAB = ZMod.flatMap expand
  where
    -- c1^k * c2^n = (alpha+beta)^k * (alpha*beta)^n
    --             = sum_i binomial(k,i) * alpha^(n+i) * beta^(n+k-i)
    expand :: Chern -> ZMod AB
    expand (Chern k n) = ZMod.fromList (map term [0 .. k])
      where
        term i = (AB (n + i) (n + k - i), binomial k i)


-- | Converts a symmetric polynomial in the 'AB' base (Chern roots) to the
-- 'Chern' base (elementary symmetric polynomials, i.e. Chern classes).
-- Calls 'error' if the input is not symmetric (see 'symmetricReduction').
abToChern :: ZMod AB -> ZMod Chern
abToChern = either notSymmetric id . symmetricReduction
  where
    notSymmetric _ = error "abToChern: input was not symmetric"

-- | Rewrites a polynomial in the Chern roots in terms of the Chern classes
-- @c1 = alpha + beta@ and @c2 = alpha * beta@.
--
-- Each step asks 'ZMod.findMaxTerm' for a term @k * alpha^n * beta^m@.
-- If @n >= m@, it records @k * c1^(n-m) * c2^m@ and subtracts the
-- Chern-root expansion of that product. The expansion contains
-- @alpha^n * beta^m@ with coefficient 1, so that term cancels.
-- (Termination presumably relies on 'ZMod.findMaxTerm' returning the
-- largest monomial in the derived 'Ord' of 'AB'; verify against FreeMod.)
--
-- @Right c@ means that the input was symmetric and equals @c@.
-- @Left (c, r)@ means that a non-symmetric remainder @r@ was found. It is
-- reported as soon as the selected term has @n < m@, and @c@ holds the
-- Chern classes extracted up to that point.
symmetricReduction :: ZMod AB -> Either (ZMod Chern, ZMod AB) (ZMod Chern)
symmetricReduction = go [] where

  -- @sofar@ accumulates the extracted (Chern monomial, coefficient) pairs;
  -- @q@ is their sum, needed by both the success and the failure branch.
  go sofar zmod = case ZMod.findMaxTerm zmod of
      Nothing -> Right q
      Just (AB n m, k)
        | n < m     -> Left (q, zmod)
        | otherwise ->
            let ch   = Chern (n-m) m
                this = ZMod.scale k $ expandToAlphaBeta_1 ch
            in  go ((ch,k):sofar) (zmod - this)
    where
      q = ZMod.fromList sofar

  -- c1^k * c2^n = (alpha+beta)^k * (alpha*beta)^n
  expandToAlphaBeta_1 :: Chern -> ZMod AB
  expandToAlphaBeta_1 (Chern k n) = ZMod.fromList [ (AB (n+i) (n+k-i) , binomial k i) | i<-[0..k] ]

-- | Convert Schur polynomials to Chern roots.
--
-- A single @s[a,b]@ (with @a >= b >= 0@) expands as
-- @sum_{j=0}^{a-b} alpha^(a-j) * beta^(b+j)@, with all coefficients 1.
-- Other indices call 'error'.
schurToAB :: ZMod Schur -> ZMod AB
schurToAB = ZMod.flatMap schurExpandAB_1 where

  schurExpandAB_1 :: Schur -> ZMod AB
  schurExpandAB_1 (Schur a b)
    | b > a     = error "schurExpandAB"
    | b < 0     = error "schurExpandAB"
    | otherwise = ZMod.fromList [ ( AB (a-j) (b+j) , 1 ) | j <- [0..a-b] ]

  -- Reference definition (Mathematica), kept as documentation only:
  --
  --   schurab[i_, j_] :=
  --    Expand[Factor[ Det[{{a^(i + 1), b^(i + 1)}, {a^j, b^j}}]] /
  --      Det[{{a, b}, {1, 1}}] ]


-- | Convert Schur polynomials to Chern classes (elementary symmetric
-- polynomials).
--
-- A single @s[a,b]@ (with @a >= b >= 0@; other indices call 'error')
-- expands as
-- @sum_{j=0}^{(a-b) div 2} (-1)^j * binomial(a-b-j, j) * c1^(a-b-2j) * c2^(b+j)@.
schurToChern :: ZMod Schur -> ZMod Chern
schurToChern = ZMod.flatMap expandOne
  where
    expandOne :: Schur -> ZMod Chern
    expandOne (Schur a b)
      | b > a || b < 0 = error "schurExpandChern_1"
      | otherwise      = ZMod.fromList (map term [0 .. d `div` 2])
      where
        d = a - b
        term j = ( Chern (d - 2*j) (b + j)
                 , paritySignValue j * binomial (d - j) j )

    -- Mathematica reference:
    --  schurcd[i_, j_] := SymmetricReduction[schurab[i, j], {a, b}, {c1, c2}][[1]]


-- | Convert Chern classes to Schur polynomials.
--
-- A single @c1^e * c2^f@ (with @e, f >= 0@; other exponents call 'error')
-- expands as @sum_{i=0}^{e div 2} catalanTriangle(e-i, i) * s[e+f-i, f+i]@.
chernToSchur :: ZMod Chern -> ZMod Schur
chernToSchur = ZMod.flatMap expandOne
  where
    expandOne :: Chern -> ZMod Schur
    expandOne (Chern e f)
      | e < 0 || f < 0 = error "chernExpandSchur"
      | otherwise      = ZMod.fromList (map term [0 .. e `div` 2])
      where
        term i = (Schur (e + f - i) (f + i), catalanTriangle (e - i) i)


-- | Convert a symmetric polynomial in the Chern roots to Schur polynomials,
-- going through the Chern-class representation. Calls 'error' (via
-- 'abToChern') if the input is not symmetric.
abToSchur :: ZMod AB -> ZMod Schur
abToSchur ab = chernToSchur (abToChern ab)

-- | Converts Chern classes to Schur polynomials by successive elimination.
--
-- Each step takes the term @k * c1^e * c2^f@ reported by
-- 'ZMod.findMaxTerm', emits @k * s[e+f, f]@, and subtracts the Chern-class
-- expansion of that Schur function. NOTE(review): presumably kept as a
-- slow reference for cross-checking 'chernToSchur' — confirm.
chernToSchurNaive :: ZMod Chern -> ZMod Schur
chernToSchurNaive input = ZMod.fromList (peel input)
  where
    peel poly = case ZMod.findMaxTerm poly of
      Nothing                -> []
      Just (Chern e f, coef) ->
        let schur     = Schur (e+f) f
            expansion = ZMod.scale coef (expandSchur schur)
        in  (schur, coef) : peel (poly - expansion)

    -- s[a,b] in terms of c1 and c2 (same formula as 'schurToChern' uses)
    expandSchur :: Schur -> ZMod Chern
    expandSchur (Schur a b)
      | b > a     = error "schurExpandChern_1"
      | b < 0     = error "schurExpandChern_1"
      | otherwise = ZMod.fromList
          [ ( Chern (a-b-2*j) (b+j) , paritySignValue j * binomial (a-b-j) j )
          | j <- [0 .. (a-b) `div` 2] ]

-- * Random polynomials for testing

-- | A random monomial @c1^a * c2^b@ with @0 <= a <= 30@ and @0 <= b <= 15@.
-- The exponent of @c1@ is drawn first.
randomChernMonom :: IO Chern
randomChernMonom = Chern <$> randomRIO (0,30) <*> randomRIO (0,15)

-- | A random Schur index @s[a+b, b]@ with @0 <= a, b <= 30@, so the first
-- index is never smaller than the second.
randomSchurMonom :: IO Schur
randomSchurMonom = mkSchur <$> randomRIO (0,30) <*> randomRIO (0,30)
  where
    mkSchur a b = Schur (a+b) b

-- | Pairs the result of @rnd@ with a random coefficient in @[-100,100]@.
-- The coefficient is drawn before @rnd@ runs.
withRandomCoeff :: IO a -> IO (a,Integer)
withRandomCoeff rnd = flip (,) <$> randomRIO (-100,100) <*> rnd

-- | A random Chern-class polynomial built from between 0 and 100 random
-- (monomial, coefficient) pairs.
randomChernPoly :: IO (ZMod Chern)
randomChernPoly =
  randomRIO (0,100) >>= \count ->
    ZMod.fromList <$> replicateM count (withRandomCoeff randomChernMonom)

-- | A random Schur polynomial built from between 0 and 100 random
-- (Schur index, coefficient) pairs.
randomSchurPoly :: IO (ZMod Schur)
randomSchurPoly = do
  count <- randomRIO (0,100)
  terms <- replicateM count (withRandomCoeff randomSchurMonom)
  return (ZMod.fromList terms)
