```
-- TODO: better name?

-- | This module contains a function to generate (equivalence classes of)
-- triangular tableaux of size /k/, strictly increasing to the right and
-- to the bottom. For example
--
-- >  1
-- >  2  4
-- >  3  5  8
-- >  6  7  9  10
--
-- is such a tableau of size 4.
-- The numbers filling a tableau always consist of an interval @[1..c]@;
-- @c@ is called the /content/ of the tableaux. There is a unique tableau
-- of minimal content @2k-1@:
--
-- >  1
-- >  2  3
-- >  3  4  5
-- >  4  5  6  7
--
-- Let us call the tableaux with maximal content (that is, @m = binomial (k+1) 2@)
-- /standard/. The number of standard tableaux are
--
-- > 1, 1, 2, 12, 286, 33592, 23178480, ...
--
-- OEIS:A003121, \"Strict sense ballot numbers\",
-- <http://www.research.att.com/~njas/sequences/A003121>.
--
-- See
-- R. M. Thrall, A combinatorial problem, Michigan Math. J. 1, (1952), 81-88.
--
-- The number of tableaux with content @c=m-d@ are
--
-- >  d=  |     0      1      2      3    ...
-- > -----+----------------------------------------------
-- >  k=2 |     1
-- >  k=3 |     2      1
-- >  k=4 |    12     18      8      1
-- >  k=5 |   286    858   1001    572    165     22     1
-- >  k=6 | 33592 167960 361114 436696 326196 155584 47320 8892 962 52 1
--
-- We call these \"Kostka tableaux\" (in the lack of a better name), since
-- they are in bijection with the simplicial cones in a canonical simplicial
-- decompositions of the Gelfand-Tsetlin cones (the content corresponds
-- to the dimension), which encode the combinatorics of Kostka numbers.
--

module Math.Combinat.Tableaux.Kostka
(
Tableau
, Tri(..)
, TriangularArray
, fromTriangularArray
, triangularArrayUnsafe
, kostkaTableaux
, _kostkaTableaux
, countKostkaTableaux
, kostkaContent
, _kostkaContent
)
where

--------------------------------------------------------------------------------

import Data.Ix
import Data.Ord
import Data.List

import Data.Array.IArray
import Data.Array.Unboxed
import Data.Array.ST

import Math.Combinat.Tableaux (Tableau)
import Math.Combinat.Helper

--------------------------------------------------------------------------------

-- | Triangular arrays
type TriangularArray a = Array Tri a

-- | Set of @(i,j)@ pairs with @i>=j>=1@.
newtype Tri = Tri { unTri :: (Int,Int) } deriving (Eq,Ord,Show)

binom2 :: Int -> Int
binom2 n = (n*(n-1)) `div` 2

index' :: Tri -> Int
index' (Tri (i,j)) = binom2 i + j - 1

-- it should be (1+8*m),
-- the 2 is a hack to be safe with the floating point stuff
deIndex' :: Int -> Tri
deIndex' m = Tri ( i+1 , m - binom2 (i+1) + 1 ) where
i = ( (floor.sqrt.(fromIntegral::Int->Double)) (2+8*m) - 1 ) `div` 2

instance Ix Tri where
index   (a,b) x = index' x - index' a
inRange (a,b) x = (u<=j && j<=v) where
u = index' a
v = index' b
j = index' x
range     (a,b) = map deIndex' [ index' a .. index' b ]
rangeSize (a,b) = index' b - index' a + 1

{-# SPECIALIZE triangularArrayUnsafe :: Tableau Int -> TriangularArray Int #-}
triangularArrayUnsafe :: Tableau a -> TriangularArray a
triangularArrayUnsafe tableau = listArray (Tri (1,1),Tri (k,k)) (concat tableau)
where k = length tableau

{-# SPECIALIZE fromTriangularArray :: TriangularArray Int -> Tableau Int #-}
fromTriangularArray :: TriangularArray a -> Tableau a
fromTriangularArray arr = (map.map) snd \$ groupBy (equating f) \$ assocs arr
where f = fst . unTri . fst

--------------------------------------------------------------------------------

-- "fractional fillings"
data Hole = Hole Int Int deriving (Eq,Ord,Show)

type ReverseTableau      = [[Int ]]
type ReverseHoleTableau  = [[Hole]]

toHole :: Int -> Hole
toHole k = Hole k 0

nextHole :: Hole -> Hole
nextHole (Hole k l) = Hole k (l+1)

{-# SPECIALIZE reverseTableau :: [[Int]] -> [[Int]] #-}
reverseTableau :: [[a]] -> [[a]]
reverseTableau = reverse . map reverse

--------------------------------------------------------------------------------

kostkaContent :: TriangularArray Int -> Int
kostkaContent arr = arr ! (snd (bounds arr))

_kostkaContent :: Tableau Int -> Int
_kostkaContent = last . last

normalize :: ReverseHoleTableau -> TriangularArray Int
normalize = snd . normalize'

-- returns ( content , tableau )
normalize' :: ReverseHoleTableau -> ( Int , TriangularArray Int )
normalize' holes = ( c , array (Tri (1,1), Tri (k,k)) xys ) where
k = length holes
c = length sorted
xys = concat \$ zipWith hs [1..] sorted
hs a xs     = map (h a) xs
h  a (ij,_) = (Tri ij , a)
sorted = groupSortBy snd (concat withPos)
withPos = zipWith f [1..] (reverseTableau holes)
f i xs = zipWith (g i) [1..] xs
g i j hole = ((i,j),hole)

--------------------------------------------------------------------------------

startHole :: [Hole] -> [Int] -> Hole
startHole (t:ts) (p:ps) = max t (toHole p)
startHole (t:ts) []     = t
startHole []     (p:ps) = toHole p
startHole []     []     = error "startHole"

-- c is the "content" of the small tableau
enumHoles :: Int -> Hole -> [Hole]
enumHoles c start@(Hole k l)
= nextHole start
: [ Hole i 0 | i <- [k+1..c] ] ++ [ Hole i 1 | i <- [k+1..c] ]

helper :: Int -> [Int] -> [Hole] -> [[Hole]]
helper c [] this = [[]]
helper c prev@(p:ps) this =
[ t:rest | t <- enumHoles c (startHole this prev), rest <- helper c ps (t:this) ]

newLines' :: Int -> [Int] -> [[Hole]]
newLines' c lastReversed = helper c last []
where
last = reverse (top : lastReversed)

newLines :: [Int] -> [[Hole]]
newLines lastReversed = newLines' (head lastReversed) lastReversed

-- | Generates all tableaux of size @k@. Effective for @k<=6@.
kostkaTableaux :: Int -> [TriangularArray Int]
kostkaTableaux 0 = [ triangularArrayUnsafe [] ]
kostkaTableaux 1 = [ triangularArrayUnsafe [[1]] ]
kostkaTableaux k = map normalize \$ concatMap f smalls where
smalls :: [ [[Int]] ]
smalls = map (reverseTableau . fromTriangularArray) \$ kostkaTableaux (k-1)
f :: [[Int]] -> [ [[Hole]] ]
f small = map (:smallhole) \$ map reverse \$ newLines (head small) where
smallhole = map (map toHole) small

_kostkaTableaux :: Int -> [Tableau Int]
_kostkaTableaux k = map fromTriangularArray \$ kostkaTableaux k

--------------------------------------------------------------------------------

countKostkaTableaux :: Int -> [Int]
countKostkaTableaux = elems . sizes'

sizes' :: Int -> UArray Int Int
sizes' k =
runSTUArray \$ do
let (a,b) = ( 2*k-1 , binom2 (k+1) )
ar <- newArray (a,b) 0 :: ST s (STUArray s Int Int)
mapM_ (worker ar) \$ kostkaTableaux k
return ar
where
worker :: STUArray s Int Int -> TriangularArray Int -> ST s ()
worker ar t = do
let c = kostkaContent t