{-# LANGUAGE DeriveFunctor         #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Efficient combinatorial algorithms over multisets, including
--   generating all permutations, partitions, subsets, cycles, and
--   other combinatorial structures based on multisets.  Note that an
--   'Eq' or 'Ord' instance on the elements is /not/ required; the
--   algorithms are careful to keep track of which things are (by
--   construction) equal to which other things, so equality testing is
--   not needed.
module Math.Combinatorics.Multiset
       ( -- * The 'Multiset' type

         Count
       , Multiset(..)
       , emptyMS, singletonMS
       , consMS, (+:)

         -- ** Conversions
       , toList
       , fromList
       , fromListEq
       , fromDistinctList
       , fromCounts
       , getCounts
       , size

         -- ** Operations
       , disjUnion
       , disjUnions

         -- * Permutations

       , permutations
       , permutationsRLE

         -- * Partitions

       , Vec
       , vPartitions
       , partitions

         -- * Submultisets

       , splits
       , kSubsets

         -- * Cycles and bracelets

       , cycles
       , bracelets
       , genFixedBracelets

         -- * Miscellaneous

       , sequenceMS

       ) where

import           Control.Arrow              (first, second, (&&&), (***))
import           Control.Monad              (forM_, when)
import           Control.Monad.Trans.Writer
import qualified Data.IntMap.Strict         as IM
import           Data.List                  (group, partition, sort)
import           Data.Maybe                 (catMaybes, fromJust)

type Count = Int

-- | A multiset is represented as a list of (element, count) pairs.
--   We maintain the invariants that the counts are always positive,
--   and no element ever appears more than once.
newtype Multiset a = MS { toCounts :: [(a, Count)] }
  deriving (Show, Functor)

-- | Construct a 'Multiset' from a list of (element, count) pairs.
--   Precondition: the counts must all be positive, and there must not
--   be any duplicate elements.
fromCounts :: [(a, Count)] -> Multiset a
fromCounts = MS

-- | Extract just the element counts from a multiset, forgetting the
--   elements.
getCounts :: Multiset a -> [Count]
getCounts = map snd . toCounts

-- | Compute the total size of a multiset.
size :: Multiset a -> Int
size = sum . getCounts

liftMS :: ([(a, Count)] -> [(b, Count)]) -> Multiset a -> Multiset b
liftMS f (MS m) = MS (f m)

-- | A multiset with no values in it.
emptyMS :: Multiset a
emptyMS = MS []

-- | Create a multiset with only a single value in it.
singletonMS :: a -> Multiset a
singletonMS a = MS [(a,1)]

-- | Add an element with multiplicity to a multiset.  Precondition:
--   the new element is distinct from all elements already in the
--   multiset.
consMS :: (a, Count) -> Multiset a -> Multiset a
consMS e@(_,c) (MS m)
  | c > 0     = MS (e:m)
  | otherwise = MS m

-- | A convenient shorthand for 'consMS'.
(+:) :: (a, Count) -> Multiset a -> Multiset a
(+:) = consMS

-- | Convert a multiset to a list.
toList :: Multiset a -> [a]
toList = expandCounts . toCounts

expandCounts :: [(a, Count)] -> [a]
expandCounts = concatMap (uncurry (flip replicate))

-- | Efficiently convert a list to a multiset, given an 'Ord' instance
--   for the elements.  This method is provided just for convenience.
--   you can also use 'fromListEq' with only an 'Eq' instance, or
--   construct 'Multiset's directly using 'fromCounts'.
fromList :: Ord a => [a] -> Multiset a
fromList = fromCounts . map (head &&& length) . group . sort

-- | Convert a list to a multiset, given an 'Eq' instance for the
--   elements.
fromListEq :: Eq a => [a] -> Multiset a
fromListEq = fromCounts . fromListEq'
  where fromListEq' []     = []
        fromListEq' (x:xs) = (x, 1 + length xEqs) : fromListEq' xNeqs
          where
            (xEqs, xNeqs) = partition (==x) xs

-- | Make a multiset with one copy of each element from a list of
--   distinct elements.
fromDistinctList :: [a] -> Multiset a
fromDistinctList = fromCounts . map (\x -> (x,1))

-- | Form the disjoint union of two multisets; i.e. we assume the two
--   multisets share no elements in common.
disjUnion :: Multiset a -> Multiset a -> Multiset a
disjUnion (MS xs) (MS ys) = MS (xs ++ ys)

-- | Form the disjoint union of a collection of multisets.  We assume
--   that the multisets all have distinct elements.
disjUnions :: [Multiset a] -> Multiset a
disjUnions = foldr disjUnion (MS [])

-- | In order to generate permutations of a multiset, we need to keep
--   track of the most recently used element in the permutation being
--   built, so that we don't use it again immediately.  The
--   'RMultiset' type (for \"restricted multiset\") records this
--   information, consisting of a multiset possibly paired with an
--   element (with multiplicity) which is also part of the multiset,
--   but should not be used at the beginning of permutations.
data RMultiset a = RMS (Maybe (a, Count)) [(a,Count)]
  deriving Show

-- | Convert a 'Multiset' to a 'RMultiset' (with no avoided element).
toRMS :: Multiset a -> RMultiset a
toRMS = RMS Nothing . toCounts

-- | Convert a 'RMultiset' to a 'Multiset'.
fromRMS :: RMultiset a -> Multiset a
fromRMS (RMS Nothing m)  = MS m
fromRMS (RMS (Just e) m) = MS (e:m)

-- | List all the distinct permutations of the elements of a
--   multiset.
--
--   For example, @permutations (fromList \"abb\") ==
--   [\"abb\",\"bba\",\"bab\"]@, whereas @Data.List.permutations
--   \"abb\" == [\"abb\",\"bab\",\"bba\",\"bba\",\"bab\",\"abb\"]@.
--   This function is equivalent to, but /much/ more efficient than,
--   @nub . Data.List.permutations@, and even works when the elements
--   have no 'Eq' instance.
--
--   Note that this is a specialized version of 'permutationsRLE',
--   where each run has been expanded via 'replicate'.
permutations :: Multiset a -> [[a]]
permutations = map expandCounts . permutationsRLE

-- | List all the distinct permutations of the elements of a multiset,
--   with each permutation run-length encoded. (Note that the
--   run-length encoding is a natural byproduct of the algorithm used,
--   not a separate postprocessing step.)
--
--   For example, @permutationsRLE [('a',1), ('b',2)] ==
--   [[('a',1),('b',2)],[('b',2),('a',1)],[('b',1),('a',1),('b',1)]]@.
--
--   (Note that although the output type is newtype-equivalent to
--   @[Multiset a]@, we don't call it that since the output may
--   violate the 'Multiset' invariant that no element should appear
--   more than once.  And indeed, morally this function does not
--   output multisets at all.)
permutationsRLE :: Multiset a -> [[(a,Count)]]
permutationsRLE (MS []) = [[]]
permutationsRLE m       = permutationsRLE' (toRMS m)

-- | List all the (run-length encoded) distinct permutations of the
--   elements of a multiset which do not start with the element to
--   avoid (if any).
permutationsRLE' :: RMultiset a -> [[(a,Count)]]

-- If only one element is left, there's only one permutation.
permutationsRLE' (RMS Nothing [(x,n)]) = [[(x,n)]]

-- Otherwise, select an element+multiplicity in all possible ways, and
-- concatenate the elements to all possible permutations of the
-- remaining multiset.
permutationsRLE' m  = [ e : p
                      | (e, m') <- selectRMS m
                      , p       <- permutationsRLE' m'
                      ]

-- | Select an element + multiplicity from a multiset in all possible
--   ways, appropriately keeping track of elements to avoid at the
--   start of permutations.
selectRMS :: RMultiset a -> [((a, Count), RMultiset a)]

-- No elements to select.
selectRMS (RMS _ [])            = []

-- Selecting from a multiset with n copies of x, avoiding e:
selectRMS (RMS e ((x,n) : ms))  =

  -- If we select all n copies of x, there are no copies of x left to avoid;
  -- stick e (if it exists) back into the remaining multiset.
  ((x,n), RMS Nothing (maybe ms (:ms) e)) :

  -- We can also select any number of copies of x from (n-1) down to 1; in each case,
  -- we avoid the remaining copies of x and put e back into the returned multiset.
  [ ( (x,k), RMS (Just (x,n-k))
                 (maybe ms (:ms) e) )
    | k <- [n-1, n-2 .. 1]
  ] ++

  -- Finally, we can recursively choose something other than x.
  map (second (consRMS (x,n))) (selectRMS (RMS e ms))

consRMS :: (a, Count) -> RMultiset a -> RMultiset a
consRMS x (RMS e m) = RMS e (x:m)


-- Some QuickCheck properties.  Of course, due to combinatorial
-- explosion these are of limited utility!
-- newtype ArbCount = ArbCount Int
--   deriving (Eq, Show, Num, Real, Enum, Ord, Integral)

-- instance Arbitrary Count where
--   arbitrary = elements (map ArbCount [1..3])

-- prop_perms_distinct :: Multiset Char ArbCount -> Bool
-- prop_perms_distinct m = length ps == length (nub ps)
--   where ps = permutations m

-- prop_perms_are_perms :: Multiset Char ArbCount -> Bool
-- prop_perms_are_perms m = all ((==l') . sort) (permutations m)
--   where l' = sort (toList m)

---------------------
-- Partitions
---------------------

-- | Element count vector.
type Vec = [Count]

-- | Componentwise comparison of count vectors.
(<|=) :: Vec -> Vec -> Bool
xs <|= ys = and $ zipWith (<=) xs ys

-- | 'vZero v' produces a zero vector of the same length as @v@.
vZero :: Vec -> Vec
vZero = map (const 0)

-- | Test for the zero vector.
vIsZero :: Vec -> Bool
vIsZero = all (==0)

-- | Do vector arithmetic componentwise.
(.+.), (.-.) :: Vec -> Vec -> Vec
(.+.) = zipWith (+)
(.-.) = zipWith (-)

-- | Multiply a count vector by a scalar.
(*.) :: Count -> Vec -> Vec
(*.) n = map (n*)

-- | 'v1 `vDiv` v2' is the largest scalar multiple of 'v2' which is
--   elementwise less than or equal to 'v1'.
vDiv :: Vec -> Vec -> Count
vDiv v1 v2 = minimum . catMaybes $ zipWith zdiv v1 v2
  where zdiv _ 0 = Nothing
        zdiv x y = Just $ x `div` y

-- | 'vInc within v' lexicographically increments 'v' with respect to
--   'within'.  For example, @vInc [2,3,5] [1,3,4] == [1,3,5]@, and
--   @vInc [2,3,5] [1,3,5] == [2,0,0]@.
vInc :: Vec -> Vec -> Vec
vInc lim v = reverse (vInc' (reverse lim) (reverse v))
  where vInc' _ []          = []
        vInc' [] (x:xs)     = x+1 : xs
        vInc' (l:ls) (x:xs) | x < l     = x+1 : xs
                            | otherwise = 0 : vInc' ls xs

-- | Generate all vector partitions, representing each partition as a
--   multiset of vectors.
--
--   This code is a slight generalization of the code published in
--
--     Brent Yorgey. \"Generating Multiset Partitions\". In: The
--     Monad.Reader, Issue 8, September 2007.
--     <http://www.haskell.org/sitewiki/images/d/dd/TMR-Issue8.pdf>
--
--   See that article for a detailed discussion of the code and how it works.
vPartitions :: Vec -> [Multiset Vec]
vPartitions v = vPart v (vZero v) where
  vPart v _ | vIsZero v = [MS []]
  vPart v vL
    | v <= vL   = []
    | otherwise = MS [(v,1)]
                : [ (v',k) +: p' | v' <- withinFromTo v (vHalf v) (vInc v vL)
                                 , k  <- [1 .. (v `vDiv` v')]
                                 , p' <- vPart (v .-. (k *. v')) v' ]

-- | 'vHalf v' computes the \"lexicographic half\" of 'v', that is,
--   the vector which is the middle element (biased towards the end)
--   in a lexicographically decreasing list of all the vectors
--   elementwise no greater than 'v'.
vHalf :: Vec -> Vec
vHalf [] = []
vHalf (x:xs) | (even x) = (x `div` 2) : vHalf xs
             | otherwise = (x `div` 2) : xs

downFrom n = [n,(n-1)..0]

-- | 'within m' generates a lexicographically decreasing list of
--   vectors elementwise no greater than 'm'.
within :: Vec -> [Vec]
within = sequence . map downFrom

-- | Clip one vector against another.
clip :: Vec -> Vec -> Vec
clip = zipWith min

-- | 'withinFromTo m s e' efficiently generates a lexicographically
--   decreasing list of vectors which are elementwise no greater than
--   'm' and lexicographically between 's' and 'e'.
withinFromTo :: Vec -> Vec -> Vec -> [Vec]
withinFromTo m s e | not (s <|= m) = withinFromTo m (clip m s) e
withinFromTo m s e | e > s = []
withinFromTo m s e = wFT m s e True True
  where
    wFT [] _ _ _ _ = [[]]
    wFT (m:ms) (s:ss) (e:es) useS useE =
        let start = if useS then s else m
            end   = if useE then e else 0
        in
          [x:xs | x <- [start,(start-1)..end],
                  let useS' = useS && x==s,
                  let useE' = useE && x==e,
                  xs <- wFT ms ss es useS' useE' ]

-- | Efficiently generate all distinct multiset partitions.  Note that
--   each partition is represented as a multiset of parts (each of
--   which is a multiset) in order to properly reflect the fact that
--   some parts may occur multiple times.
partitions :: Multiset a -> [Multiset (Multiset a)]
partitions (MS []) = [MS []]
partitions (MS m)  = (map . fmap) (combine elts) $ vPartitions counts
  where (elts, counts) = unzip m
        combine es cs  = MS . filter ((/=0) . snd) $ zip es cs

-- | Generate all splittings of a multiset into two submultisets,
--   i.e. all size-two partitions.
splits :: Multiset a -> [(Multiset a, Multiset a)]
splits (MS [])        = [(MS [], MS [])]
splits (MS ((x,n):m)) =
  for [0..n] $ \k ->
    map (addElt x k *** addElt x (n-k)) (splits (MS m))

-- | Generate all size-k submultisets.
kSubsets :: Count -> Multiset a -> [Multiset a]
kSubsets 0 _              = [MS []]
kSubsets _ (MS [])        = []
kSubsets k (MS ((x,n):m)) =
  for [0 .. min k n] $ \j ->
    map (addElt x j) (kSubsets (k - j) (MS m))

for = flip concatMap

addElt _ 0 = id
addElt x k = ((x,k) +:)

----------------------------------------------------------------------
-- Cycles (aka Necklaces)
----------------------------------------------------------------------

-- | Generate all distinct cycles, aka necklaces, with elements taken
--   from a multiset.  See J. Sawada, \"A fast algorithm to generate
--   necklaces with fixed content\", J. Theor. Comput. Sci. 301 (2003)
--   pp. 477-489.
--
--   Given the ordering on the elements of the multiset based on their
--   position in the multiset representation (with \"smaller\"
--   elements first), in @map reverse (cycles m)@, each generated
--   cycle is lexicographically smallest among all its cyclic shifts,
--   and furthermore, the cycles occur in reverse lexicographic
--   order. (It's simply more convenient/efficient to generate the
--   cycles reversed in this way, and of course we get the same set of
--   cycles either way.)
--
--   For example, @cycles (fromList \"aabbc\") ==
--   [\"cabba\",\"bcaba\",\"cbaba\",\"bbcaa\",\"bcbaa\",\"cbbaa\"]@.
cycles :: Multiset a -> [[a]]
cycles (MS [])         = []   -- no such thing as an empty cycle
cycles m@(MS ((x1,n1):xs))
  | n1 == 1    = (cycles' n 2 1 [(0,x1)] (reverse $ zip [1..] xs))
  | otherwise =  (cycles' n 2 1 [(0,x1)] (reverse $ zip [0..] ((x1,n1-1):xs)))
  where n = sum . getCounts $ m

-- | The first parameter is the length of the necklaces being
--   generated.  The second parameter @p@ is the length of the longest
--   prefix of @pre@ which is a Lyndon word, i.e. an aperiodic
--   necklace.  @pre@ is the current (reversed) prefix of the
--   necklaces being generated.
cycles' :: Int -> Int -> Int -> [(Int, a)] -> [(Int, (a,Count))] -> [[a]]
cycles' n _ p pre [] | n `mod` p == 0 = [map snd pre]
                     | otherwise      = []

cycles' n t p pre xs =
  (takeWhile ((>=atp) . fst) xs) >>= \(j, (xj,_)) ->
    cycles' n (t+1) (if j == atp then p else t)
      ((j,xj):pre)
      (remove j xs)
  where atp = fst $ pre !! (p - 1)

remove :: Int -> [(Int, (a, Int))] -> [(Int, (a, Int))]
remove _ [] = []
remove j (x@(j',(xj,nj)):xs)
  | j == j' && nj == 1 = xs
  | j == j'            = (j',(xj,nj-1)):xs
  | otherwise          = x:remove j xs

----------------------------------------------------------------------
-- Bracelets
----------------------------------------------------------------------

-- Some utilities

--------------------------------------------------
-- Indexable and Snocable classes

class Snocable p a where
  (|>) :: p -> a -> p

-- 1-based indexing
class Indexable p where
  (!) :: p -> Int -> Int

--------------------------------------------------
-- Prenecklaces

type PreNecklace = [Int]

-- A prenecklace, stored backwards, along with its length and its
-- first element cached for quick retrieval.
data Pre = Pre !Int (Maybe Int) PreNecklace
  deriving (Show)

emptyPre :: Pre
emptyPre = Pre 0 Nothing []

getPre :: Pre -> PreNecklace
getPre (Pre _ _ as) = reverse as

instance Snocable Pre Int where
  (Pre 0 _ []) |> a  = Pre 1 (Just a) [a]
  (Pre t a1 as) |> a = Pre (t+1) a1 (a:as)

instance Indexable Pre where
  _ ! 0 = 0
  (Pre _ (Just a1) _) ! 1 = a1
  (Pre t _ as) ! i = as !! (t-i)
    -- as stores  a_t .. a_1.
    -- a_1 is the last element, i.e. with index t-1.
    -- a_2 has index t-2.
    -- In general, a_i has index t-i.

--------------------------------------------------
-- Run-length encoding

-- Run-length encodings.  Stored in *reverse* order for easy access to
-- the end.
data RLE a = RLE !Int !Int [(a,Int)]
  deriving (Show)
  -- First Int is the total length of the decoded list.
  -- Second Int is the number of blocks.

emptyRLE :: RLE a
emptyRLE = RLE 0 0 []

compareRLE :: Ord a => [(a,Int)] -> [(a,Int)] -> Ordering
compareRLE [] [] = EQ
compareRLE [] _  = LT
compareRLE _ []  = GT
compareRLE ((a1,n1):rle1) ((a2,n2):rle2)
  | (a1,n1) == (a2,n2) = compareRLE rle1 rle2
  | a1 < a2 = LT
  | a1 > a2 = GT
  | (n1 < n2 && (null rle1 || fst (head rle1) < a2)) || (n1 > n2 && not (null rle2) && a1 < fst (head rle2)) = LT
  | otherwise = GT

instance Indexable (RLE Int) where
  (RLE _ _ []) ! _ = error "Bad index in (!) for RLE"
  (RLE n b ((a,v):rest)) ! i
    | i <= v = a
    | otherwise = (RLE (n-v) (b-1) rest) ! (i-v)

instance Eq a => Snocable (RLE a) a where
  (RLE _ _ []) |> a' = RLE 1 1 [(a',1)]
  (RLE n b rle@((a,v):rest)) |> a'
    | a == a'   = RLE (n+1) b     ((a,v+1):rest)
    | otherwise = RLE (n+1) (b+1) ((a',1):rle)

--------------------------------------------------
-- Prenecklaces + RLE

-- Prenecklaces along with a run-length encoding.
data Pre' = Pre' Pre (RLE Int)
  deriving Show

emptyPre' :: Pre'
emptyPre' = Pre' emptyPre emptyRLE

getPre' :: Pre' -> PreNecklace
getPre' (Pre' pre _) = getPre pre

instance Indexable Pre' where
  _ ! 0 = 0
  (Pre' (Pre len _ _) rle) ! i = rle ! (len - i + 1)

instance Snocable Pre' Int where
  (Pre' p rle) |> a = Pre' (p |> a) (rle |> a)

--------------------------------------------------
-- Bracelet generation

type Bracelet = [Int]

-- | An optimized bracelet generation algorithm, based on
--   S. Karim et al, "Generating Bracelets with Fixed Content".
--   <http://www.cis.uoguelph.ca/~sawada/papers/fix-brace.pdf>
--
--   @genFixedBracelets n content@ produces all bracelets (unique up
--   to rotation and reflection) of length @n@ using @content@, which
--   consists of a list of pairs where the pair (a,i) indicates that
--   element a may be used up to i times.  It is assumed that the elements
--   are drawn from [0..k].
genFixedBracelets :: Int -> [(Int,Int)] -> [Bracelet]
genFixedBracelets n [(0,k)] | k >= n = [replicate k 0]
                            | otherwise = []
genFixedBracelets n content = execWriter (go 1 1 0 (IM.fromList content) emptyPre')
  where
    go :: Int -> Int -> Int -> IM.IntMap Int -> Pre' -> Writer [Bracelet] ()
    go _ _ _ con _ | IM.keys con == [0] = return ()
    go t p r con pre@(Pre' (Pre _ _ as) _)
      | t > n =
          when (take (n - r) as >= reverse (take (n-r) as) && n `mod` p == 0) $
            tell [getPre' pre]
      | otherwise = do
          let a' = pre ! (t-p)
          forM_ (dropWhile (< a') $ IM.keys con) $ \j -> do
            let con' = decrease j con
                pre' = pre |> j
                c = checkRev2 t pre'
                p' | j /= a'   = t
                   | otherwise = p
            when (c == EQ) $ go (t+1) p' t con' pre'
            when (c == GT) $ go (t+1) p' r con' pre'

    decrease :: Int -> IM.IntMap Int -> IM.IntMap Int
    decrease j con
      | IM.null con = con
      | otherwise   = IM.alter q j con
      where
        q (Just 1)   = Nothing
        q (Just cnt) = Just (cnt-1)
        q _          = Nothing

    checkRev2 _ (Pre' _ (RLE _ _ rle)) = compareRLE rle (reverse rle)

-- | Generate all distinct bracelets (lists considered equivalent up
--   to rotation and reversal) from a given multiset.  The generated
--   bracelets are in lexicographic order, and each is
--   lexicographically smallest among its rotations and reversals.
--   See @genFixedBracelets@ for a slightly more general routine with
--   references.
--
--   For example, @bracelets $ fromList \"RRRRRRRLLL\"@ yields
--
--   > ["LLLRRRRRRR","LLRLRRRRRR","LLRRLRRRRR","LLRRRLRRRR"
--   > ,"LRLRLRRRRR","LRLRRLRRRR","LRLRRRLRRR","LRRLRRLRRR"]
bracelets :: Multiset a -> [[a]]
bracelets ms@(MS cnts) = bs
  where
    contentMap = IM.fromList (zip [0..] (map fst cnts))
    content    = zipWith (\i (_,n) -> (i,n)) [0..] cnts
    rawBs      = genFixedBracelets (size ms) content
    bs         = map (map (fromJust . flip IM.lookup contentMap)) rawBs

----------------------------------------------------------------------
-- sequenceMS
----------------------------------------------------------------------

-- | Take a multiset of lists, and select one element from each list
--   in every possible combination to form a list of multisets.  We
--   assume that all the list elements are distinct.
sequenceMS :: Multiset [a] -> [Multiset a]
sequenceMS = map disjUnions
           . sequence
           . map (\(xs, n) -> kSubsets n (MS $ uncollate (xs, n)))
           . toCounts

uncollate :: ([a], Count) -> [(a, Count)]
uncollate (xs, n) = map (\x -> (x,n)) xs