-- | Mobius inversion for the coarsening poset of partitions

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances #-}
module Math.RootLoci.Geometry.Mobius 
  ( Partition(..) 
  -- * The refinement poset of partitions
  , coarserThan , finerThan
  , (.==.) , (./=.) , (.<=.) , (.>=.) , (.<.) , (.>.) 
  -- * closures
  , fastClosure , fastAntiClosure
  , closureSet , closureSet'
  -- * Mobius function
  , zetaOf , mobiusOf
  -- * helpers
  , firstLevelDown , firstLevelUp  
  -- * set partitions
  , closureSetOfSetPartition
  , firstLevelDownSetP
  )
  where

--------------------------------------------------------------------------------

import Data.List

import qualified Data.Map.Strict as Map ; import Data.Map.Strict (Map)
import qualified Data.Set        as Set ; import Data.Set        (Set)

import Math.Combinat.Partitions.Integer
import Math.Combinat.Partitions.Set
import Math.Combinat.Sets

import qualified Math.RootLoci.Algebra.FreeMod as ZMod

import Math.RootLoci.Algebra
import Math.RootLoci.Misc

--------------------------------------------------------------------------------

{-
indicator :: Bool -> Integer
indicator b = if b then 1 else 0

kronecker' :: Partition -> ZMod Partition
kronecker' p = ZMod.singleton p 1

kronecker :: Partition -> Partition -> Integer
kronecker p q = indicator (p .==. q)

zeta :: Partition -> Partition -> Integer
zeta p q = indicator (p .<=. q)
-}

--------------------------------------------------------------------------------
-- * Mobius function

-- | Zeta function of the refinement poset, as a formal sum.
--
-- @zetaOf p@ is the sum, with every coefficient equal to 1, of all
-- partitions in @'closureSet' p@, i.e. @p@ itself and every partition
-- coarser than it.
--
-- The binding is kept point-free (@zetaOf = pcache build@) so the value
-- built by 'pcache' is shared across calls.
-- NOTE(review): assumes 'pcache' memoises per partition; confirm in Misc.
zetaOf :: Partition -> ZMod Partition
zetaOf = pcache build
  where
    build p = ZMod.fromList [ (q, 1) | q <- Set.toList (closureSet p) ]

-- | Mobius function of the refinement poset, as a formal sum.
--
-- Defined by the recursion
--
-- > mobiusOf p = p - sum [ mobiusOf q | q <- partitions strictly coarser than p ]
--
-- so the coefficient of @q@ in @mobiusOf p@ is @mu(q,p)@. Here @mu@ is the
-- Mobius function of the poset in which coarser partitions lie below finer
-- ones: @mu(p,p) = 1@ and @mu(q,p) = - sum_{q <= r < p} mu(q,r)@.
--
-- Kept point-free so the value built by 'pcache' is shared. The recursive
-- calls go through the public 'mobiusOf'.
mobiusOf :: Partition -> ZMod Partition
mobiusOf = pcache build
  where
    build p         = ZMod.singleton p 1 `ZMod.sub` strictlyBelow p
    strictlyBelow p = ZMod.sum (map mobiusOf (Set.toList (closureSet' p)))

--------------------------------------------------------------------------------
-- * The refinement poset of partitions

-- | @coarserThan p q@ holds when @p@ can be reached from @q@ by repeatedly
-- merging two parts. Zero merges are allowed, so every partition is
-- coarser than itself.
coarserThan :: Partition -> Partition -> Bool
coarserThan p q = p `Set.member` closureSet q

-- | @finerThan q p@ holds when @q@ is finer than (or equal to) @p@, i.e.
-- when @p@ is coarser than @q@.
finerThan :: Partition -> Partition -> Bool
finerThan = flip coarserThan

-- | @p .<=. q@: @p@ is coarser than or equal to @q@.
(.<=.) :: Partition -> Partition -> Bool
p .<=. q = coarserThan p q

-- | @p .>=. q@: @p@ is finer than or equal to @q@.
(.>=.) :: Partition -> Partition -> Bool
p .>=. q = finerThan p q

-- | Equality of partitions, spelled in the poset notation.
(.==.) :: Partition -> Partition -> Bool
p .==. q = p == q

-- | Inequality of partitions, spelled in the poset notation.
(./=.) :: Partition -> Partition -> Bool
p ./=. q = p /= q

-- | @p .<. q@: @p@ is strictly coarser than @q@.
(.<.) :: Partition -> Partition -> Bool
p .<. q = (p .<=. q) && (p /= q)

-- | @p .>. q@: @p@ is strictly finer than @q@.
(.>.) :: Partition -> Partition -> Bool
p .>. q = (p .>=. q) && (p /= q)

--------------------------------------------------------------------------------
-- * Efficient first-level merge and split (one step in the poset)

-- | Insert a value into a list sorted in weakly decreasing order, keeping
-- the list sorted. The new value goes in front of the first element that
-- is @<=@ it.
insertRevSorted :: Int -> [Int] -> [Int]
insertRevSorted x ys = larger ++ x : rest
  where
    (larger, rest) = span (> x) ys

-- | Insert two values into a weakly decreasing list, keeping it sorted.
-- @y@ is inserted first, then @x@.
insertRevSorted2 :: Int -> Int -> [Int] -> [Int]
insertRevSorted2 x y zs = insertRevSorted x (insertRevSorted y zs)

-- | Insert a run of equal values @zs@ into a list of groups ordered by
-- decreasing value, and return the flattened result.
--
-- The run is placed in front of the first non-empty group whose value is
-- @<=@ the run's value, or at the end if there is no such group. Empty
-- groups contribute nothing to the output. An empty run simply flattens
-- the groups.
--
-- Example:
--
-- > insertGroup_ [3,3] [[5,5,5],[4],[1,1,1,1]] == [5,5,5,4,3,3,1,1,1,1]
--
insertGroup_ :: [Int] -> [[Int]] -> [Int]
insertGroup_ [] = concat
insertGroup_ zs@(z:_) = go
  where
    go (xs@(x:_) : rest)
      | z >= x    = zs ++ xs ++ concat rest
      | otherwise = xs ++ go rest
    go ([] : rest) = go rest
    go []          = zs

-- | Insert a run of equal values as a new group into a list of groups
-- ordered by decreasing value. It is placed before the first non-empty
-- group whose value is @<=@ the run's value, or at the end.
--
-- Empty groups met before the insertion point are dropped. Those after it
-- are kept. An empty run leaves the groups untouched. For every input:
--
-- > concat (insertGroup zs gs) == insertGroup_ zs gs
--
insertGroup :: [Int] -> [[Int]] -> [[Int]]
insertGroup []          = id
insertGroup new@(v : _) = place
  where
    place (grp@(g : _) : more)
      | v >= g    = new : grp : more
      | otherwise = grp : place more
    place ([] : more) = place more
    place []          = [new]

-- | Insert two runs into a decreasing list of groups and flatten the
-- result. @ys@ is inserted first (as a group), then @xs@.
insertGroup2_ :: [Int] -> [Int] -> [[Int]] -> [Int]
insertGroup2_ xs ys groups = insertGroup_ xs (insertGroup ys groups)

-- | Insert two runs as new groups into a decreasing list of groups.
-- @ys@ is inserted first, then @xs@.
insertGroup2 :: [Int] -> [Int] -> [[Int]] -> [[Int]]
insertGroup2 xs ys groups = insertGroup xs (insertGroup ys groups)

-- | Every way of picking one element out of a list, paired with the
-- remaining elements in their original order. Results follow the position
-- of the picked element.
choose1 :: [a] -> [(a,[a])]
choose1 xs =
  [ (picked, before ++ after)
  | (before, picked : after) <- zip (inits xs) (tails xs)
  ]

-- | Every way of picking two elements (at positions @i < j@) out of a list,
-- together with the remaining elements in their original order. Results
-- are ordered by @i@, then by @j@.
choose2 :: [a] -> [(a,a,[a])]
choose2 xs =
  [ (first, second, pre ++ mid ++ post)
  | (pre, first : rest)  <- zip (inits xs) (tails xs)
  , (mid, second : post) <- zip (inits rest) (tails rest)
  ]

-- | All partitions obtained by merging exactly two parts into one (one
-- step down in the coarsening order).
--
-- The parts are first grouped into runs of equal values. A merge either
-- takes two parts from the same run, or one part from each of two
-- different runs. The sum is then re-inserted in sorted position.
-- NOTE(review): relies on the parts being stored in weakly decreasing
-- order, which the insert helpers assume.
firstLevelDown :: Partition -> [Partition]
firstLevelDown (Partition parts) = mergeWithinRun ++ mergeAcrossRuns
  where
    runs = group parts
    -- two equal parts taken from the same run (needs a run of length >= 2)
    mergeWithinRun =
      [ Partition (insertRevSorted (a + b) (insertGroup_ leftover others))
      | (a : b : leftover, others) <- choose1 runs
      ]
    -- one part taken from each of two different runs
    mergeAcrossRuns =
      [ Partition (insertRevSorted (a + b) (insertGroup2_ restA restB others))
      | (a : restA, b : restB, others) <- choose2 runs
      ]

-- | All partitions obtained by splitting one part @z@ into two positive
-- parts @x@ and @z-x@ with @x <= z-x@ (one step up in the order).
--
-- Only the first element of each run of equal parts is split; the rest of
-- that run is re-inserted as a group.
-- NOTE(review): relies on the parts being stored in weakly decreasing
-- order, which the insert helpers assume.
firstLevelUp :: Partition -> [Partition]
firstLevelUp (Partition parts) =
  [ Partition (insertRevSorted2 small (big - small) (insertGroup_ copies others))
  | (big : copies, others) <- choose1 (group parts)
  , small <- [1 .. big `div` 2]
  ]

-- | Naive reference for 'firstLevelDown', used as a sanity check. It merges
-- every 2-element choice of parts, rebuilds each result with 'mkPartition',
-- and deduplicates with 'unique'.
firstLevelDownNaive :: Partition -> [Partition]
firstLevelDownNaive (Partition parts) = unique (concatMap merge (choose' 2 parts))
  where
    merge ([a, b], others) = [mkPartition (a + b : others)]
    merge _                = []

-- | Naive reference for 'firstLevelUp'. It splits every part @z@ into
-- @x@ and @z-x@ for all @1 <= x < z@, rebuilds each result with
-- 'mkPartition', and deduplicates with 'unique'.
firstLevelUpNaive :: Partition -> [Partition]
firstLevelUpNaive (Partition parts) = unique (concatMap split (choose' 1 parts))
  where
    split ([z], others) = [ mkPartition (a : z - a : others) | a <- [1 .. z - 1] ]
    split _             = []

-- | Sanity check: the sorted output of 'firstLevelDown' equals that of
-- 'firstLevelDownNaive'.
-- NOTE(review): assumes 'unique' returns a sorted, duplicate-free list.
checkDown :: Partition -> Bool
checkDown p = fast == naive
  where
    fast  = sort (firstLevelDown p)
    naive = firstLevelDownNaive p

-- | Sanity check: the sorted output of 'firstLevelUp' equals that of
-- 'firstLevelUpNaive'.
-- NOTE(review): assumes 'unique' returns a sorted, duplicate-free list.
checkUp :: Partition -> Bool
checkUp p = fast == naive
  where
    fast  = sort (firstLevelUp p)
    naive = firstLevelUpNaive p

--------------------------------------------------------------------------------

-- | Closure of a single partition: the partition itself plus every
-- partition reachable by repeatedly merging two parts. This is a
-- depth-first traversal over 'firstLevelDown' that uses a visited set and
-- no cross-call caching.
fastClosure :: Partition -> Set Partition
fastClosure start = visit Set.empty [start]
  where
    visit !seen [] = seen
    visit !seen (q : stack)
      | q `Set.member` seen = visit seen stack
      | otherwise           = visit (Set.insert q seen) (firstLevelDown q ++ stack)

-- | \"Anticlosure\" of a single partition, i.e. its closure in the opposite
-- poset: the partition itself plus every partition reachable by repeatedly
-- splitting one part into two. This is a depth-first traversal over
-- 'firstLevelUp' that uses a visited set.
fastAntiClosure :: Partition -> Set Partition
fastAntiClosure start = visit Set.empty [start]
  where
    visit !seen [] = seen
    visit !seen (q : stack)
      | q `Set.member` seen = visit seen stack
      | otherwise           = visit (Set.insert q seen) (firstLevelUp q ++ stack)

--------------------------------------------------------------------------------

-- | Closure of a partition (itself plus everything reachable by merging
-- parts), built from the closures of its one-step merges.
--
-- The closure of every visited partition is cached lazily via 'monoCache'
-- and reused; the original note calls this the fastest version. The
-- binding must stay point-free (@closureSet = memo@) so the cache is
-- shared between calls.
closureSet :: Partition -> Set Partition
closureSet = memo
  where
    memo    = monoCache build
    build p = foldl' step (Set.singleton p) (firstLevelDown p)
    -- skip merges already covered by an earlier sub-closure
    step acc q
      | q `Set.member` acc = acc
      | otherwise          = acc `Set.union` memo q

-- | The closure without the stratum itself, i.e. the partitions that are
-- strictly coarser than @p@.
closureSet' :: Partition -> Set Partition
closureSet' p = p `Set.delete` closureSet p

--------------------------------------------------------------------------------
-- * set partitions

-- | All set partitions obtained by merging (concatenating) two blocks.
-- Each result is rebuilt with 'toSetPartition'.
-- NOTE(review): presumably 'toSetPartition' normalises block order;
-- confirm in combinat.
firstLevelDownSetP :: SetPartition -> [SetPartition]
firstLevelDownSetP (SetPartition blocks) = concatMap merge (choose' 2 blocks)
  where
    merge ([a, b], others) = [toSetPartition ((a ++ b) : others)]
    merge _                = []
  
-- | Closure of a set partition: itself plus every set partition reachable
-- by repeatedly merging two blocks. Sub-closures are cached lazily via
-- 'monoCache'. The binding stays point-free so the cache is shared.
closureSetOfSetPartition :: SetPartition -> Set SetPartition
closureSetOfSetPartition = memo
  where
    memo    = monoCache build
    build p = foldl' step (Set.singleton p) (firstLevelDownSetP p)
    -- skip merges already covered by an earlier sub-closure
    step acc q
      | q `Set.member` acc = acc
      | otherwise          = acc `Set.union` memo q
 
--------------------------------------------------------------------------------