```
-- | Partitions. Partitions are nonincreasing sequences of positive integers.
--
--   Donald E. Knuth: The Art of Computer Programming, vol 4, pre-fascicle 3B.

module Math.Combinat.Partitions
( -- * Type and basic stuff
Partition
, toPartition
, toPartitionUnsafe
, mkPartition
, isPartition
, fromPartition
, height
, width
, heightWidth
, weight
, dualPartition
, _dualPartition
, elements
, _elements
, countAutomorphisms
, _countAutomorphisms
-- * Generation
, partitions'
, _partitions'
, countPartitions'
, partitions
, _partitions
, countPartitions
, allPartitions'
, allPartitions
, countAllPartitions'
, countAllPartitions
-- * Partitions of multisets, vector partitions
, partitionMultiset
, IntVector
, vectorPartitions
, _vectorPartitions
, fasc3B_algorithm_M
)
where

import Data.List
import Data.Array.Unboxed

import Math.Combinat.Helper
import Math.Combinat.Numbers (factorial,binomial,multinomial)

--------------------------------------------------------------------------------

-- | An integer partition, stored as the list of its parts.
--
--   The intended invariant is that the list is a monotone decreasing
--   (more precisely: nonincreasing) sequence of positive integers.
--   Note that the constructor itself checks nothing: 'mkPartition' sorts its
--   input and drops the nonpositive elements, 'toPartition' only checks the
--   ordering (see 'isPartition'), and 'toPartitionUnsafe' does no check at all.
newtype Partition = Partition [Int] deriving (Eq,Ord,Show,Read)

-- | Builds a partition from an arbitrary list: the nonpositive elements are
--   thrown away, and what remains is sorted with 'reverseCompare'.
mkPartition :: [Int] -> Partition
mkPartition = Partition . sortBy reverseCompare . filter (> 0)

-- | Wraps a list as a partition without validating anything.
--   The caller is responsible for passing a decreasing sequence.
toPartitionUnsafe :: [Int] -> Partition
toPartitionUnsafe parts = Partition parts

-- | Converts a list to a partition, after checking it with 'isPartition'
--   (see the note there about what exactly is checked!).
--   Calls 'error' if the check fails.
toPartition :: [Int] -> Partition
toPartition xs
  | isPartition xs = toPartitionUnsafe xs
  | otherwise      = error "toPartition: not a partition"

-- | Tests whether every element is at least as large as its successor.
--
-- Note: only the ordering is checked; we /do not/ check for negative
-- (or zero) elements. This can be useful when working with symmetric
-- functions. It may also change in the future...
isPartition :: [Int] -> Bool
isPartition xs = and (zipWith (>=) xs (drop 1 xs))

-- | Unwraps a partition, giving back the underlying list of parts.
fromPartition :: Partition -> [Int]
fromPartition p = case p of
  Partition parts -> parts

-- | The first element of the sequence; 0 for the empty partition.
height :: Partition -> Int
height (Partition [])      = 0
height (Partition (top:_)) = top

-- | The number of elements of the sequence.
width :: Partition -> Int
width p = length parts
  where
    Partition parts = p

-- | The pair of the 'height' and the 'width' of a partition.
heightWidth :: Partition -> (Int,Int)
heightWidth p = (h, w)
  where
    h = height p
    w = width p

-- | The weight of the partition, that is, the sum of its parts.
weight :: Partition -> Int
weight p = case p of
  Partition parts -> sum parts

-- | The dual (or conjugate) partition; a wrapper around '_dualPartition'.
dualPartition :: Partition -> Partition
dualPartition (Partition rows) = Partition $ _dualPartition rows

-- | The dual partition, on plain lists: the i-th element of the result
--   is the number of input elements which are at least i, for i running
--   from 1 to the first element of the input.
--
-- (we could be more efficient, but it hardly matters)
_dualPartition :: [Int] -> [Int]
_dualPartition []           = []
_dualPartition rows@(top:_) = map columnLength [1..top]
  where
    columnLength i = length (filter (>= i) rows)

-- | The list of @(row,column)@ coordinates of the boxes of the partition,
--   both indices starting from 1, listed row by row.
--   A wrapper around '_elements'.
--
-- Example:
--
-- > elements (toPartition [5,4,1]) ==
-- >   [ (1,1), (1,2), (1,3), (1,4), (1,5)
-- >   , (2,1), (2,2), (2,3), (2,4)
-- >   , (3,1)
-- >   ]
elements :: Partition -> [(Int,Int)]
elements (Partition shape) = _elements shape

-- | The @(row,column)@ coordinates of a shape given as a list of row
--   lengths. Rows and columns are both numbered from 1; the result is
--   listed row by row, and within a row by increasing column.
_elements :: [Int] -> [(Int,Int)]
_elements shape = concat (zipWith row [1..] shape)
  where
    row i len = [ (i,j) | j <- [1..len] ]

-- | Computes the number of \"automorphisms\" of a given partition;
--   a wrapper around '_countAutomorphisms'.
countAutomorphisms :: Partition -> Integer
countAutomorphisms part = _countAutomorphisms (fromPartition part)

-- | Splits the list into runs of consecutive equal elements, and applies
--   'multinomial' to the list of the run lengths.
--
--   NOTE(review): for a monotone sequence the run lengths are the
--   multiplicities of the distinct parts; what exactly 'multinomial'
--   computes is defined in "Math.Combinat.Numbers" - confirm there.
_countAutomorphisms :: [Int] -> Integer
_countAutomorphisms xs = multinomial [ length run | run <- group xs ]

---------------------------------------------------------------------------------

-- | Partitions of d, fitting into a given rectangle, as lists.
--
--   The height bounds the first (largest) part, the width bounds the
--   number of parts. The results are generated with the first part
--   increasing from 1, and recursively so for the remaining parts.
_partitions'
  :: (Int,Int)     -- ^ (height,width)
  -> Int           -- ^ d
  -> [[Int]]
_partitions' _ 0 = [[]]
_partitions' (h,w) d
  | h == 0 || w == 0 = []
  | otherwise        = concatMap startingWith [1 .. min d h]
  where
    -- all the solutions whose first part is i: the rest must have
    -- parts at most i, and one part fewer
    startingWith i = map (i :) (_partitions' (i, w-1) (d-i))

-- | Partitions of d, fitting into a given rectangle.
--   These are the lists produced by '_partitions'', in the same order,
--   each one wrapped with 'toPartitionUnsafe'.
partitions'
  :: (Int,Int)     -- ^ (height,width)
  -> Int           -- ^ d
  -> [Partition]
partitions' hw d = [ toPartitionUnsafe xs | xs <- _partitions' hw d ]

-- | The number of partitions of d fitting into a given
--   (height,width) rectangle. This follows the same recursion as
--   the generation of the partitions, but only adds up the counts.
countPartitions' :: (Int,Int) -> Int -> Integer
countPartitions' _ 0 = 1
countPartitions' (h,w) d
  | h == 0 || w == 0 = 0
  | otherwise        = sum (map startingWith [1 .. min d h])
  where
    -- the number of solutions whose first part is i
    startingWith i = countPartitions' (i, w-1) (d-i)

-- | Partitions of d, as lists: '_partitions'' with a d-by-d bounding box.
_partitions :: Int -> [[Int]]
_partitions d = _partitions' box d
  where
    box = (d,d)

-- | Partitions of d: 'partitions'' with a d-by-d bounding box.
partitions :: Int -> [Partition]
partitions d =
  let square = (d,d)
  in  partitions' square d

-- | The number of partitions of d, computed as 'countPartitions''
--   with a d-by-d bounding box.
countPartitions :: Int -> Integer
countPartitions d = countPartitions' square d
  where
    square = (d,d)

-- | All partitions fitting into a given rectangle, grouped by weight:
--   the i-th list (counting from 0) contains the partitions of i, where
--   i runs from 0 to the area @height*width@ of the rectangle.
allPartitions'
  :: (Int,Int)        -- ^ (height,width)
  -> [[Partition]]
allPartitions' (h,w) = map (partitions' (h,w)) [0 .. h*w]

-- | All partitions up to a given degree, grouped by weight:
--   the i-th list (counting from 0) contains the partitions of i.
allPartitions :: Int -> [[Partition]]
allPartitions d = map partitions [0..d]

-- | The number of all partitions fitting into a (height,width) rectangle:
--
-- > # = binomial (h+w) h
--
--   The smaller one of h and w is passed to 'binomial' as the second
--   argument, which selects the same coefficient by symmetry.
countAllPartitions' :: (Int,Int) -> Integer
countAllPartitions' (h,w) = binomial total smaller
  where
    total   = h + w
    smaller = min h w
    -- alternative kept from the original as a comment:
    --   sum [ countPartitions' (h,w) i | i <- [0..d] ] where d = h*w

-- | The number of all partitions of weight at most d, computed by
--   adding up 'countPartitions' for the weights from 0 to d.
countAllPartitions :: Int -> Integer
countAllPartitions d = sum (map countPartitions [0..d])

--------------------------------------------------------------------------------

-- | Partitions of a multiset.
--
--   The multiset is given as a list, in any order. The result is the list
--   of its partitions; each partition is a list of parts, and each part is
--   a list of elements in ascending order. The partitions are produced by
--   running 'fasc3B_algorithm_M' on the multiplicities of the distinct
--   elements.
--
--   The empty multiset has exactly one partition, the empty one, so the
--   result for @[]@ is @[[]]@.
partitionMultiset :: (Eq a, Ord a) => [a] -> [[[a]]]
partitionMultiset [] = [[]]
partitionMultiset xs = (map . map) (expand . elems) (fasc3B_algorithm_M counts)
  where
    -- the distinct elements in ascending order, and their multiplicities;
    -- 'group' never produces an empty run, so no run is dropped by the pattern
    (distinct, counts) =
      unzip [ (y, length run) | run@(y:_) <- group (sort xs) ]
    -- turns a list of multiplicities back into a list of elements
    expand ns = concat (zipWith replicate ns distinct)

-- | Integer vectors, represented as unboxed arrays of 'Int's indexed by 'Int's.
--   The indexing starts from 1: the vectors produced by 'fasc3B_algorithm_M'
--   have bounds @(1,m)@, where @m@ is the length of its input list.
type IntVector = UArray Int Int

-- | Vector partitions. Basically a synonym for 'fasc3B_algorithm_M':
--   only the list of the elements of the input vector is passed on, so the
--   bounds of the input play no role.
vectorPartitions :: IntVector -> [[IntVector]]
vectorPartitions vec = fasc3B_algorithm_M (elems vec)

-- | Vector partitions on plain lists: runs 'fasc3B_algorithm_M', and
--   converts every part of every partition back to the list of its elements.
_vectorPartitions :: [Int] -> [[[Int]]]
_vectorPartitions xs = [ map elems parts | parts <- fasc3B_algorithm_M xs ]

-- | Generates all vector partitions
--   (\"algorithm M\" in Knuth).
--   The order is decreasing lexicographic.
--
--   The input is the list of the components of the vector to be partitioned
--   (equivalently, the multiplicities of a multiset). Every partition is
--   returned as a list of parts; each part is a vector with bounds @(1,m)@,
--   where @m@ is the length of the input list.
--
--   Zero components are allowed: they are skipped during the generation, and
--   show up as zeros in every part. In particular the empty list and the
--   all-zero vector have exactly one partition, the empty one, so the result
--   is @[[]]@ for them. A negative component is reported with 'error'.
fasc3B_algorithm_M :: [Int] -> [[IntVector]]
{- note to self: Knuth's descriptions of algorithms are still totally unreadable -}
fasc3B_algorithm_M xs
  | any (< 0) xs = error "fasc3B_algorithm_M: negative component"
  | null start   = [[]]
  | otherwise    = worker [start]
  where

    m = length xs

    -- A frame is a list of triples (c,u,v): c is the index of a component,
    -- v is the value of that component in the part described by the frame,
    -- and u is the amount which was available for it (so u-v is what is
    -- left for the frames pushed later). The first frame takes everything.
    -- Components which are zero in the input are left out: a triple with
    -- u == v == 0 would be copied by 'subtract_a' forever.
    start = [ (j,x,x) | (j,x) <- zip [1..] xs, x > 0 ]

    -- The stack is a list of frames, the most recently pushed one first.
    -- One step: complete the stack, emit it as a partition, then look
    -- for the next stack to continue with.
    worker stack@(_:_) =
      case decrease stack' of
        Nothing      -> [visited]
        Just stack'' -> visited : worker stack''
      where
        stack'  = subtract_rec stack
        visited = map to_vector stack'

    -- Decrements the last nonzero v of the top frame, and resets the v of
    -- the triples after it to their u. A frame whose only nonzero v is a
    -- single leading 1 cannot be decreased: it is popped instead, and we
    -- are finished when there is nothing left below it.
    decrease (top:rest) =
      case span (\(_,_,v) -> v==0) (reverse top) of
        ( _ , [(_,_,1)] ) -> case rest of
          [] -> Nothing
          _  -> decrease rest
        ( second , (c,u,v):first ) -> Just (modified:rest) where
          modified =
            reverse first ++
            (c,u,v-1) :
            [ (c',u',u') | (c',u',_) <- reverse second ]
        _ -> error "should not happen"

    -- The part described by a frame: component c gets the value v, and
    -- the components not present in the frame are zero.
    to_vector cuvs =
      accumArray (flip const) 0 (1,m)
        [ (c,v) | (c,_,v) <- cuvs ]

    -- Keeps pushing the frame computed from the top one,
    -- until that would be empty.
    subtract_rec frames@(top:_) =
      case subtract_a top of
        []  -> frames
        new -> subtract_rec (new:frames)

    -- The next frame, computed from the remainders u-v: the value v is
    -- kept as long as the remainder is at least v; from the first triple
    -- where this fails, 'subtract_b' takes over.
    subtract_a [] = []
    subtract_a full@((c,u,v):rest) =
      if w >= v
        then (c,w,v) : subtract_a rest
        else           subtract_b full
      where w = u - v

    -- Takes the whole remainder of every component, dropping the
    -- components with nothing left.
    subtract_b [] = []
    subtract_b ((c,u,v):rest) =
      if w /= 0
        then (c,w,w) : subtract_b rest
        else           subtract_b rest
      where w = u - v

--------------------------------------------------------------------------------
```