-- | Some auxiliary functions

{-# LANGUAGE BangPatterns, TypeSynonymInstances, FlexibleInstances, DeriveFunctor #-}
module Math.RootLoci.Misc.Common where

--------------------------------------------------------------------------------

import Data.List
import Data.Monoid
import Data.Ratio

import Control.Monad

import Math.Combinat.Numbers
import Math.Combinat.Sign
import Math.Combinat.Partitions.Integer 
import Math.Combinat.Partitions.Set
import Math.Combinat.Sets

import qualified Data.Map.Strict as Map
import Data.Map (Map)

-- import qualified Math.RootLoci.Algebra.FreeMod as ZMod
-- import Math.RootLoci.Algebra.SymmPoly
-- import Math.RootLoci.Geometry.Cohomology

--------------------------------------------------------------------------------
-- * Pairs

-- | Two values of the same type.
--
-- Mapping a function over a 'Pair' applies it to both components.
data Pair a = Pair a a
  deriving (Eq, Ord, Show)

instance Functor Pair where
  fmap f (Pair x y) = Pair (f x) (f y)

--------------------------------------------------------------------------------
-- * Lists

{-# SPECIALIZE unique :: [Partition] -> [Partition] #-}
-- | Sorts the list and drops repeated elements, so every distinct value
-- (with respect to 'Eq') appears exactly once, in ascending order.
unique :: Ord a => [a] -> [a]
unique = dedup . sort where
  -- keep the head of each run of equal elements, skip the rest of the run
  dedup []       = []
  dedup (y:rest) = y : dedup (dropWhile (y ==) rest)

-- | Another name for 'histogram': maps each distinct element of the list
-- to the number of times it occurs.
count :: Ord b => [b] -> Map b Integer
count xs = histogram xs

{-# SPECIALIZE histogram :: [Partition] -> Map Partition Integer #-}
-- | Counts how often each distinct element occurs in the list.
-- Elements that never occur are absent from the resulting map.
histogram :: Ord b => [b] -> Map b Integer
histogram xs = Map.fromListWith (+) [ (x, 1) | x <- xs ]

--------------------------------------------------------------------------------
-- * Maps
  
-- | Looks up a key and removes it from the map.
--
-- Returns the value stored under the key ('Nothing' if it is absent),
-- together with the map with that key deleted.
deleteLookup :: Ord a => a -> Map a b -> (Maybe b, Map a b)
deleteLookup k table = (found, remaining) where
  found     = Map.lookup k table
  remaining = Map.delete k table

-- | Like 'deleteLookup', but assumes the key is present.
--
-- If the key is missing, the first component of the result is an 'error'
-- thunk. It only fails when that component is demanded; the second
-- component (the map with the key deleted) is still usable.
unsafeDeleteLookup :: Ord a => a -> Map a b -> (b, Map a b)
unsafeDeleteLookup k table = (value, Map.delete k table) where
  value = maybe (error "unsafeDeleteLookup: key not found") id (Map.lookup k table)

--------------------------------------------------------------------------------
-- * Partitions

-- | @aut(mu)@ is the number of symmetries of the partition @mu@:
--
-- > aut(mu) = prod_r (e_r)!
--
-- where @mu = (1^e1 2^e2 .. k^ek)@, i.e. @e_r@ is the multiplicity of the
-- part @r@, taken from the second components of 'toExponentialForm'.
aut :: Partition -> Integer
aut = product . map (factorial . snd) . toExponentialForm

--------------------------------------------------------------------------------
-- * Set partitions
 
-- | Builds a set partition from a partition by filling the blocks with
-- consecutive integers from left to right (see 'linearIndices'). The shape
-- of the result gives back the input partition.
defaultSetPartition :: Partition -> SetPartition
defaultSetPartition nu = SetPartition (linearIndices nu)

-- | Produce linear indices from a partition @nu@ (to encode the diagonal
-- map @Delta_nu@).
--
-- Each part @a@ receives the next @a@ consecutive integers, starting
-- from 1. For example, parts @[3,1]@ give @[[1,2,3],[4]]@.
linearIndices :: Partition -> [[Int]]
linearIndices (Partition ps) = zipWith block offsets ps where
  -- running totals of the parts: 0, a1, a1+a2, ...
  offsets        = scanl (+) 0 ps
  block from len = [from + 1 .. from + len]

--------------------------------------------------------------------------------
-- * Signs

-- | Types whose values carry a sign. For the instances defined here,
-- 'signOf' gives @Just Plus@ for positive values, @Just Minus@ for negative
-- values and 'Nothing' for zero.
class IsSigned a where
  signOf :: a -> Maybe Sign

-- | The sign of a number, determined by comparing it with @0@ using
-- 'compare'. Zero has no sign.
signOfNum :: (Ord a, Num a) => a -> Maybe Sign
signOfNum x = fromOrdering (compare x 0) where
  fromOrdering LT = Just Minus
  fromOrdering EQ = Nothing
  fromOrdering GT = Just Plus

instance IsSigned Int where
  signOf = signOfNum

instance IsSigned Integer where
  signOf = signOfNum

instance IsSigned Rational where
  signOf = signOfNum

--------------------------------------------------------------------------------
-- * Numbers

-- | Converts a 'Rational' that is known to be a whole number into an
-- 'Integer'. Calls 'error' when the denominator is not 1.
fromRat :: Rational -> Integer
fromRat r
  | denominator r == 1 = numerator r
  | otherwise          = error "fromRat: not an integer"

-- | Exact integer division: @safeDiv a b@ returns @q@ such that @a == b * q@.
--
-- Calls 'error' when @b@ does not divide @a@; the message shows the
-- decomposition @a = b * q + r@ with the nonzero remainder @r@ (as computed
-- by 'divMod'). Division by zero raises the usual arithmetic exception
-- from 'divMod'.
safeDiv :: Integer -> Integer -> Integer
safeDiv a b = case divMod a b of
  (q, 0) -> q
  -- the message names this function (it previously said "saveDiv")
  (q, r) -> error $ "safeDiv: " ++ show a ++ " = " ++ show b ++ " * " ++ show q ++ " + " ++ show r

--------------------------------------------------------------------------------
-- * Combinatorics

-- | Chooses @(n-1)@ elements out of @n@: every sublist obtained by deleting
-- exactly one element, keeping the original order.
--
-- The @i@-th result omits the @i@-th element. For example,
-- @chooseN1 [1,2,3] == [[2,3],[1,3],[1,2]]@, and the empty list gives no
-- results.
chooseN1 :: [a] -> [[a]]
chooseN1 xs = zipWith (++) (inits xs) (drop 1 (tails xs))
  
-- | Evaluates the @k@-th elementary symmetric polynomial at the given
-- numbers. For every @k@-element sublist produced by 'choose', the chosen
-- elements are multiplied together, and these products are then summed.
-- Both the sum and the products are strict left folds.
symPolyNum :: Num a => Int -> [a] -> a
symPolyNum k = foldl' (+) 0 . map (foldl' (*) 1) . choose k

--------------------------------------------------------------------------------