```
-- | Gelfand-Tsetlin patterns and Kostka numbers.
--
-- Gelfand-Tsetlin patterns (or tableaux) are triangular arrays like
--
-- > [ 3 ]
-- > [ 3 , 2 ]
-- > [ 3 , 1 , 0 ]
-- > [ 2 , 0 , 0 , 0 ]
--
-- with both rows and columns non-increasing non-negative integers.
-- Note: these are in bijection with the semi-standard Young tableaux.
--
-- If we add the further restriction that
-- the top diagonal reads @lambda@,
-- and the diagonal sums are partial sums of @mu@, where @lambda@ and @mu@ are two
-- partitions (in this case @lambda=[3,2]@ and @mu=[2,1,1,1]@),
-- then the number of the resulting patterns
-- or tableaux is the Kostka number @K(lambda,mu)@.
-- Actually @mu@ doesn't even need to be non-increasing.
--

{-# LANGUAGE BangPatterns, ScopedTypeVariables #-}
module Math.Combinat.Tableaux.GelfandTsetlin where

--------------------------------------------------------------------------------

import Data.List
import Data.Maybe
import Data.Monoid
import Data.Ord

import Control.Monad
import Control.Monad.Trans.State

import Data.Map (Map)
import qualified Data.Map as Map

import Math.Combinat.Partitions.Integer
import Math.Combinat.Tableaux
import Math.Combinat.Helper
import Math.Combinat.ASCII

--------------------------------------------------------------------------------
-- * Kostka numbers

-- | The Kostka number @K(lambda,mu)@, computed by counting Gelfand-Tsetlin
-- patterns (this simply calls 'countKostkaGelfandTsetlinPatterns').
-- See for example <http://en.wikipedia.org/wiki/Kostka_number>
--
-- The result is zero unless @lambda@ dominates @mu@:
--
-- > [ mu | mu <- partitions (weight lam) , kostkaNumber lam mu > 0 ] == dominatedPartitions lam
--
kostkaNumber :: Partition -> Partition -> Int
kostkaNumber lambda mu = countKostkaGelfandTsetlinPatterns lambda mu

-- | Reference implementation of the Kostka numbers; very naive and slow.
--
-- It walks through the tableaux returned by
-- @semiStandardYoungTableaux (length mu) lambda@ and counts those whose
-- entries, taken in increasing order with their multiplicities, are exactly
-- @(1,mu_1), (2,mu_2), ...@.
kostkaNumberReferenceNaive :: Partition -> Partition -> Int
kostkaNumberReferenceNaive plambda (Partition mu) = length (filter hasContentMu candidates)
  where
    candidates = semiStandardYoungTableaux (length mu) plambda

    -- (entry, multiplicity) pairs, sorted by entry; 'group' never produces
    -- an empty list, so the pattern below does not drop anything
    contentOf t = [ (x, length xs) | xs@(x:_) <- group (sort (concat t)) ]

    hasContentMu t = contentOf t == zip [1..] mu

--------------------------------------------------------------------------------

-- | Lists all (positive) Kostka numbers @K(lambda,mu)@ with the given @lambda@:
--
-- > kostkaNumbersWithGivenLambda lambda == Map.fromList [ (mu , kostkaNumber lambda mu) | mu <- dominatedPartitions lambda ]
--
-- It's much faster than computing the individual Kostka numbers.
--
-- The implementation is a memoised recursion over the sequences produced by
-- @prevLambdas0@: the table for a partition is assembled from the tables of
-- the strictly smaller partitions, by prepending the weight difference @t@
-- to every key whose first part does not exceed @t@. The memo table (the
-- state) maps a partition to its finished table.
--
{-# SPECIALIZE kostkaNumbersWithGivenLambda :: Partition -> Map Partition Int     #-}
{-# SPECIALIZE kostkaNumbersWithGivenLambda :: Partition -> Map Partition Integer #-}
kostkaNumbersWithGivenLambda :: forall coeff. Num coeff => Partition -> Map Partition coeff
kostkaNumbersWithGivenLambda (Partition lam) = evalState (worker lam) Map.empty where

  worker :: [Int] -> State (Map Partition (Map Partition coeff)) (Map Partition coeff)
  worker unlam = case unlam of
    [] -> return $ Map.singleton (Partition []) 1
    _  -> do
      cache <- get
      case Map.lookup (Partition unlam) cache of
        Just sol -> return sol
        Nothing  -> do
          let s = foldl' (+) 0 unlam
          subsols <- forM (prevLambdas0 unlam) $ \p -> do
            let t = s - foldl' (+) 0 p
                f (Partition xs , c) = case xs of
                  (y:_) -> if t >= y then Just (Partition (t:xs) , c) else Nothing
                  []    -> if t >  0 then Just (Partition [t]    , c) else Nothing
            -- The candidates include @unlam@ itself (then t == 0). The
            -- recursive call must stay behind this guard: now that the
            -- state is threaded through, recursing on @unlam@ itself would
            -- never terminate.
            if t > 0
              then do
                sub <- worker p
                return $ Map.fromList $ mapMaybe f $ Map.toList sub
              else return Map.empty

          let sol = Map.unionsWith (+) subsols
          -- Insert into the *current* memo table. Re-using the snapshot
          -- taken before the recursive calls would throw away every entry
          -- they have just stored.
          cache' <- get
          put $! Map.insert (Partition unlam) sol cache'
          return sol

-- Needs a decreasing sequence @l1, l2, ..., lk@ as input.
--
-- Lists every sequence whose i-th entry (for @i < k@) lies between @l(i+1)@
-- and @l(i)@, optionally followed by one more entry taken from @[1..lk]@.
-- The empty input has no results.
prevLambdas0 :: [Int] -> [[Int]]
prevLambdas0 []           = []
prevLambdas0 (top : rest) = between top rest
  where
    -- @between hi lows@: @hi@ is the upper bound for the next entry,
    -- the head of @lows@ (if there is one) is its lower bound
    between hi lows = case lows of
      []       -> [] : map (\j -> [j]) [1 .. hi]
      [lo]     -> [ [x] | x <- [lo .. hi] ] ++ [ [x, y] | x <- [lo .. hi] , y <- [1 .. lo] ]
      (lo:los) -> concatMap (\x -> map (x :) (between lo los)) [lo .. hi]

-- | Lists all (positive) Kostka numbers @K(lambda,mu)@ with the given @mu@:
--
-- > kostkaNumbersWithGivenMu mu == Map.fromList [ (lambda , kostkaNumber lambda mu) | lambda <- dominatingPartitions mu ]
--
-- The work is done by 'iteratedPieriRule' (the iterated Pieri rule), which
-- is relatively fast.
--
kostkaNumbersWithGivenMu :: Partition -> Map Partition Int
kostkaNumbersWithGivenMu (Partition parts) = iteratedPieriRule reversedParts
  where
    -- the parts of @mu@ are handed over in reversed order
    reversedParts = reverse parts

--------------------------------------------------------------------------------
-- * Gelfand-Tsetlin patterns

-- | A Gelfand-Tsetlin tableau (pattern), stored as a list of its rows.
type GT = [[Int]]

-- | ASCII rendering of a Gelfand-Tsetlin pattern: every entry is converted
-- with 'asciiShow' and the resulting cells are laid out by 'tabulate'.
asciiGT :: GT -> ASCII
asciiGT rows = tabulate alignment separators cells
  where
    alignment  = (HRight, VTop)
    separators = (HSepSpaces 1, VSepEmpty)
    cells      = map (map asciiShow) rows

-- | All Kostka-Gelfand-Tsetlin patterns for the pair @(lambda,mu)@.
-- This unwraps @mu@ and defers to 'kostkaGelfandTsetlinPatterns''.
kostkaGelfandTsetlinPatterns :: Partition -> Partition -> [GT]
kostkaGelfandTsetlinPatterns lambda pmu = case pmu of
  Partition mu -> kostkaGelfandTsetlinPatterns' lambda mu

-- | Generates all Kostka-Gelfand-Tsetlin tableaux, that is, triangular arrays like
--
-- > [ 3 ]
-- > [ 3 , 2 ]
-- > [ 3 , 1 , 0 ]
-- > [ 2 , 0 , 0 , 0 ]
--
-- with both rows and columns non-increasing such that
-- the top diagonal reads @lambda@ (in this case @lambda=[3,2]@) and the diagonal sums
-- are partial sums of @mu@ (in this case @mu=[2,1,1,1]@).
--
-- The number of such GT tableaux is the Kostka
-- number @K(lambda,mu)@.
--
-- The result is empty if @mu@ has a negative entry, if the two weights
-- differ, or if @lambda@ does not dominate @mu@. When both weights are zero
-- (in particular when both arguments are empty) the result is the single
-- empty pattern @[ [] ]@.
--
kostkaGelfandTsetlinPatterns' :: Partition -> [Int] -> [GT]
kostkaGelfandTsetlinPatterns' plam@(Partition lambda0) mu0
  | any (< 0) mu0                         = []     -- total, unlike @minimum mu0 < 0@ which fails on an empty list
  | wlam == 0                             = if wmu == 0 then [ [] ] else []
  | wmu  == wlam && plam `dominates` pmu  = list
  | otherwise                             = []
  where

    -- the dominance test works on partitions, so @mu@ is converted first
    pmu = mkPartition mu0

    nlam = length lambda0
    nmu  = length mu0

    n = max nlam nmu

    -- both sequences padded with zeros to the common length @n@
    lambda = lambda0 ++ replicate (n - nlam) 0
    mu     = mu0     ++ replicate (n - nmu ) 0

    revlam = reverse lambda

    wmu  = sum' mu
    wlam = sum' lambda

    list = worker
             revlam
             (scanl1 (+) mu)
             (replicate (n-1) 0)
             (replicate (n  ) 0)
             []

    worker
      :: [Int]       -- lambda_i in reverse order
      -> [Int]       -- partial sums of mu
      -> [Int]       -- sums of the tails of previous rows
      -> [Int]       -- last row
      -> [[Int]]     -- the lower part of GT tableau we accumulated so far (this is not needed if we only want to count)
      -> [GT]

    worker (rl:rls) (smu:smus) (a:acc) (lastx0:lastrowt) table = stuff
      where
        -- upper bound for the first entry of the new row
        x0 = smu - a
        -- every candidate row is as long as the previous row, which is
        -- non-empty in this clause, so 'tail' and 'init' are safe here
        stuff = concat
          [ worker rls smus (zipWith (+) acc (tail row)) (init row) (row:table)
          | row <- boundedNonIncrSeqs' x0 (map (max rl) (max lastx0 x0 : lastrowt)) lambda
          ]
    worker [rl] _ _ _ table = [ [rl]:table ]
    worker []   _ _ _ _     = [ []         ]

    -- all non-increasing sequences between the (elementwise) lower and upper
    -- bounds, whose first entry is additionally at most the given number;
    -- the lower bounds are clamped at zero
    boundedNonIncrSeqs' :: Int -> [Int] -> [Int] -> [[Int]]
    boundedNonIncrSeqs' = go where
      go h0 (a:as) (b:bs) = [ h:hs | h <- [(max 0 a)..(min h0 b)] , hs <- go h as bs ]
      go _  []     _      = [[]]
      go _  _      []     = [[]]

--------------------------------------------------------------------------------

-- | Counts the Kostka-Gelfand-Tsetlin patterns without building them, that
-- is, it returns the corresponding Kostka number:
--
-- > countKostkaGelfandTsetlinPatterns lambda mu == length (kostkaGelfandTsetlinPatterns lambda mu)
--
-- The count is @0@ when the weights differ or @lambda@ does not dominate
-- @mu@, and @1@ when both weights are zero.
--
countKostkaGelfandTsetlinPatterns :: Partition -> Partition -> Int
countKostkaGelfandTsetlinPatterns plam@(Partition lambda0) pmu@(Partition mu0)
  | lamWeight == 0                                  = if muWeight == 0 then 1 else 0
  | muWeight == lamWeight && plam `dominates` pmu   = total
  | otherwise                                       = 0
  where

    size = max (length lambda0) (length mu0)

    -- pads a sequence (of length at most @size@) with zeros up to @size@
    padded xs = take size (xs ++ repeat 0)

    lambda = padded lambda0
    mu     = padded mu0

    muWeight  = sum' mu
    lamWeight = sum' lambda

    total = count (reverse lambda) (scanl1 (+) mu) (replicate (size - 1) 0) (replicate size 0)

    -- arguments, in order: the lambda_i in reverse order; the partial sums
    -- of mu; the sums of the tails of the previous rows; the last row
    count :: [Int] -> [Int] -> [Int] -> [Int] -> Int
    count (rl:rls) (smu:smus) (a:acc) (prev0:prevTail) = sum' (map descend rows)
      where
        x0          = smu - a
        lowerBounds = map (max rl) (max prev0 x0 : prevTail)
        rows        = boundedNonIncrSeqs' x0 lowerBounds lambda
        -- each row is as long as the (non-empty) previous row
        descend row = count rls smus (zipWith (+) acc (tail row)) (init row)
    count [_] _ _ _ = 1
    count []  _ _ _ = 1

    -- all non-increasing sequences between the elementwise bounds, the first
    -- entry being at most @cap@; lower bounds are clamped at zero
    boundedNonIncrSeqs' :: Int -> [Int] -> [Int] -> [[Int]]
    boundedNonIncrSeqs' cap (lo:los) (hi:his) =
      concatMap (\h -> map (h :) (boundedNonIncrSeqs' h los his)) [max 0 lo .. min cap hi]
    boundedNonIncrSeqs' _ _ _ = [[]]

--------------------------------------------------------------------------------

{-

-- | All non-increasing sequences between a lower and an upper bound
boundedNonIncrSeqs :: [Int] -> [Int] -> [[Int]]
boundedNonIncrSeqs as bs = case bs of
(h0:_) -> boundedNonIncrSeqs' h0 as bs
[]     -> [[]]

-- | All non-increasing sequences between a lower and an upper bound, and also less-or-equal than the given number
boundedNonIncrSeqs' :: Int -> [Int] -> [Int] -> [[Int]]
boundedNonIncrSeqs' = go where
go h0 (a:as) (b:bs) = [ h:hs | h <- [(max 0 a)..(min h0 b)] , hs <- go h as bs ]
go _  []     _      = [[]]
go _  _      []     = [[]]

-- | All non-decreasing sequences between a lower and an upper bound
boundedNonDecrSeqs :: [Int] -> [Int] -> [[Int]]
boundedNonDecrSeqs = boundedNonDecrSeqs' 0

-- | All non-decreasing sequences between a lower and an upper bound, and also greater-or-equal than the given number
boundedNonDecrSeqs' :: Int -> [Int] -> [Int] -> [[Int]]
boundedNonDecrSeqs' h0 = go (max 0 h0) where
go h0 (a:as) (b:bs) = [ h:hs | h <- [(max h0 a)..b] , hs <- go h as bs ]
go _  []     _      = [[]]
go _  _      []     = [[]]

-}

--------------------------------------------------------------------------------
-- * The iterated Pieri rule

-- | Schur expansion of the product @h[n1]*h[n2]*h[n3]*...*h[nk]@, obtained by
-- iterating the Pieri rule starting from the empty partition.
-- The coefficients are in fact the Kostka numbers; the following holds:
--
-- > Map.toList (iteratedPieriRule (fromPartition mu))  ==  [ (lam, kostkaNumber lam mu) | lam <- dominatingPartitions mu ]
--
-- This should be faster than computing all these Kostka numbers one by one.
--
iteratedPieriRule :: Num coeff => [Int] -> Map Partition coeff
iteratedPieriRule ns = iteratedPieriRule' (Partition []) ns

-- | Iterates the Pieri rule over @n1, n2, ..., nk@, starting from the single
-- term @lambda@ with coefficient 1.
--
-- The original documentation describes the result as the Schur expansion of
-- @h[lambda]*h[n1]*h[n2]*h[n3]*...*h[nk]@.
-- NOTE(review): since the starting point is one term indexed by @lambda@,
-- this presumably means @s[lambda]*h[n1]*...*h[nk]@ - please confirm.
iteratedPieriRule' :: Num coeff => Partition -> [Int] -> Map Partition coeff
iteratedPieriRule' plambda = iteratedPieriRule'' seed
  where
    seed = (plambda, 1)

-- Like @iteratedPieriRule'@, but the starting term comes with an explicit
-- coefficient. Every step replaces each term @(lam,c)@ of the current
-- expansion by the terms @(p,c)@ for all @p@ in @pieriRule lam n@, adding up
-- the coefficients of equal partitions.
{-# SPECIALIZE iteratedPieriRule'' :: (Partition,Int    ) -> [Int] -> Map Partition Int     #-}
{-# SPECIALIZE iteratedPieriRule'' :: (Partition,Integer) -> [Int] -> Map Partition Integer #-}
iteratedPieriRule'' :: Num coeff => (Partition,coeff) -> [Int] -> Map Partition coeff
iteratedPieriRule'' (plambda,coeff0) ns = foldl' step (Map.singleton plambda coeff0) ns
  where
    step terms n = Map.fromListWith (+)
      [ (p, c) | (lam, c) <- Map.toList terms , p <- pieriRule lam n ]

--------------------------------------------------------------------------------

-- | Schur expansion of the product @e[n1]*e[n2]*e[n3]*...*e[nk]@, obtained by
-- iterating the dual Pieri rule starting from the empty partition.
-- The coefficients are in fact the Kostka numbers; the following holds:
--
-- > Map.toList (iteratedDualPieriRule (fromPartition mu))  ==
-- >   [ (dualPartition lam, kostkaNumber lam mu) | lam <- dominatingPartitions mu ]
--
-- This should be faster than computing all these Kostka numbers one by one.
-- It is a tiny bit slower than 'iteratedPieriRule'.
--
iteratedDualPieriRule :: Num coeff => [Int] -> Map Partition coeff
iteratedDualPieriRule ns = iteratedDualPieriRule' (Partition []) ns

-- | Iterates the dual Pieri rule over @n1, n2, ..., nk@, starting from the
-- single term @lambda@ with coefficient 1.
--
-- The original documentation describes the result as the Schur expansion of
-- @e[lambda]*e[n1]*e[n2]*e[n3]*...*e[nk]@.
-- NOTE(review): since the starting point is one term indexed by @lambda@,
-- this presumably means @s[lambda]*e[n1]*...*e[nk]@ - please confirm.
iteratedDualPieriRule' :: Num coeff => Partition -> [Int] -> Map Partition coeff
iteratedDualPieriRule' plambda = iteratedDualPieriRule'' seed
  where
    seed = (plambda, 1)

-- Like @iteratedDualPieriRule'@, but the starting term comes with an explicit
-- coefficient. Every step replaces each term @(lam,c)@ of the current
-- expansion by the terms @(p,c)@ for all @p@ in @dualPieriRule lam n@, adding
-- up the coefficients of equal partitions.
{-# SPECIALIZE iteratedDualPieriRule'' :: (Partition,Int    ) -> [Int] -> Map Partition Int     #-}
{-# SPECIALIZE iteratedDualPieriRule'' :: (Partition,Integer) -> [Int] -> Map Partition Integer #-}
iteratedDualPieriRule'' :: Num coeff => (Partition,coeff) -> [Int] -> Map Partition coeff
iteratedDualPieriRule'' (plambda,coeff0) ns = foldl' step (Map.singleton plambda coeff0) ns
  where
    step terms n = Map.fromListWith (+)
      [ (p, c) | (lam, c) <- Map.toList terms , p <- dualPieriRule lam n ]

--------------------------------------------------------------------------------
```