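-- | Infinite lazy partition tables (used for caching).
--
-- Besides tables indexed by integer partitions, this module also provides
-- finite and infinite lazy tables indexed by set partitions and by integers.
-- We cache almost all computations, since they would otherwise typically be
-- executed many times; this really helps performance.
--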

{-# LANGUAGE Rank2Types #-} 
module Math.RootLoci.Misc.PTable where

--------------------------------------------------------------------------------

import Data.List

import Math.Combinat.Classes
import Math.Combinat.Partitions.Integer
import Math.Combinat.Partitions.Set

import qualified Data.Map.Lazy as LMap

import Math.RootLoci.Algebra.SymmPoly

--------------------------------------------------------------------------------
-- * Finite lazy partition tables

-- | A finite table of values indexed by integer partitions.
--
-- It is backed by a lazy map ("Data.Map.Lazy"), so the stored values are
-- thunks that are only computed when first demanded.
newtype PTable a = PTable (LMap.Map Partition a)

-- | Build the table for all partitions of @n@, pairing each partition with
-- a (not yet evaluated) application of the given function.
createPTable :: (Partition -> a) -> Int -> PTable a
createPTable f n = PTable (LMap.fromList (map withValue (partitions n)))
  where
    withValue p = (p, f p)

-- | Look up the value stored for a partition in a finite table.
--
-- Fails with an 'error' naming the missing partition when it is not a key
-- of the table. Because 'createPTable' keys its table by @partitions n@,
-- this presumably means the partition's weight differs from @n@.
lookupPTable :: Partition -> PTable a -> a
lookupPTable p (PTable lmap) = LMap.findWithDefault missing p lmap
  where
    -- Lazy default: only evaluated (and thus only fails) on a miss.
    missing = error ("lookupPTable: partition not found in table: " ++ show p)

--------------------------------------------------------------------------------
-- * Infinite lazy partition tables

-- | An infinite sequence of partition tables.
--
-- As built by 'createPSeries', the element at index @n@ is the table for
-- @partitions n@. 'lookupPSeries' selects an element by partition weight.
newtype PSeries a = PSeries [PTable a]
  
-- | Build one lazy 'PTable' for each weight @0, 1, 2, ...@.
--
-- The list is infinite, and each table is only constructed when it is
-- first reached.
createPSeries :: (Partition -> a) -> PSeries a
createPSeries f = PSeries (map (createPTable f) [0 ..])

-- | Look up a partition in an infinite series.
--
-- The table is chosen by the partition's weight, and the partition is then
-- looked up in that table.
lookupPSeries :: Partition -> PSeries a -> a
lookupPSeries part (PSeries tables) = lookupPTable part table
  where
    table = tables !! partitionWeight part

--------------------------------------------------------------------------------
-- * Finite lazy set-partition tables

-- | A finite table of values indexed by set partitions.
--
-- It is backed by a lazy map ("Data.Map.Lazy"), so the stored values are
-- thunks that are only computed when first demanded.
newtype SetPTable a = SetPTable (LMap.Map SetPartition a)

-- | Build the table for all set partitions produced by @setPartitions n@.
--
-- Each set partition is paired with a (not yet evaluated) application of
-- the given function.
createSetPTable :: (SetPartition -> a) -> Int -> SetPTable a
createSetPTable f n = SetPTable (LMap.fromList (map withValue (setPartitions n)))
  where
    withValue sp = (sp, f sp)

-- | Look up the value stored for a set partition in a finite table.
--
-- Fails with an 'error' naming the missing set partition when it is not a
-- key of the table. Because 'createSetPTable' keys its table by
-- @setPartitions n@, this presumably means the set partition is of a
-- different size than @n@.
lookupSetPTable :: SetPartition -> SetPTable a -> a
lookupSetPTable p (SetPTable lmap) = LMap.findWithDefault missing p lmap
  where
    -- Lazy default: only evaluated (and thus only fails) on a miss.
    missing = error ("lookupSetPTable: set partition not found in table: " ++ show p)

--------------------------------------------------------------------------------
-- * Infinite lazy set-partition tables

-- | An infinite sequence of set-partition tables.
--
-- As built by 'createSetPSeries', the element at index @n@ is the table for
-- @setPartitions n@. 'lookupSetPSeries' indexes it by the total number of
-- elements in a set partition.
newtype SetPSeries a = SetPSeries [SetPTable a]
  
-- | Build one lazy 'SetPTable' for each size @0, 1, 2, ...@.
--
-- The list is infinite, and each table is only constructed when it is
-- first reached.
createSetPSeries :: (SetPartition -> a) -> SetPSeries a
createSetPSeries f = SetPSeries (map (createSetPTable f) [0 ..])

-- | Look up a set partition in an infinite series.
--
-- The table is chosen by the total number of elements across all blocks,
-- and the set partition is then looked up in that table.
lookupSetPSeries :: SetPartition -> SetPSeries a -> a
lookupSetPSeries setp (SetPSeries tables) = lookupSetPTable setp table
  where
    table = tables !! size setp
    -- total number of elements over all blocks of the set partition
    size (SetPartition blocks) = foldl' (\acc blk -> acc + length blk) 0 blocks

--------------------------------------------------------------------------------
-- * polymorphic caching 

-- | Memoise a function that is polymorphic in the Chern base.
--
-- Two caches are built with 'monoCache'. The first holds the results of
-- @calc@ passed through @spec1' ChernRoot@; the second holds them passed
-- through @spec1' ChernClass@. The returned function hands both cached
-- results to @select1@, which presumably picks the one matching the
-- requested @base@. NOTE(review): confirm @spec1'@/@select1@ semantics in
-- "Math.RootLoci.Algebra.SymmPoly".
polyCache1 
  :: (CacheKey key) 
  => (forall base. ChernBase base => key -> f base)     -- ^ polymorphic function to be cached
  -> (forall base. ChernBase base => key -> f base)     -- ^ memoised version of the same function
polyCache1 calc = \key -> select1 (cacheAB key, cacheChern key)  where
  -- The caches are bound in a @where@ clause that scopes over @calc@ only
  -- (the lambda binds @key@), so all calls of the returned function share
  -- them. Moving @key@ to the left of @=@ would risk rebuilding them per call.
  cacheAB    = monoCache $ \k -> spec1' ChernRoot  (calc k)
  cacheChern = monoCache $ \k -> spec1' ChernClass (calc k)

-- | Like 'polyCache1', but for results with one extra layer of structure
-- (@f (g base)@).
--
-- The two caches are specialised through @spec2'@ to 'ChernRoot' and
-- 'ChernClass', and @select2@ presumably picks the one matching the
-- requested @base@. NOTE(review): confirm in
-- "Math.RootLoci.Algebra.SymmPoly".
polyCache2 
  :: (CacheKey key) 
  => (forall base. ChernBase base => key -> f (g base))     -- ^ polymorphic function to be cached
  -> (forall base. ChernBase base => key -> f (g base))     -- ^ memoised version of the same function
polyCache2 calc = \key -> select2 (cacheAB key, cacheChern key)  where
  -- Bound over @calc@ only (the lambda binds @key@), so the caches are
  -- shared by all calls of the returned function.
  cacheAB    = monoCache $ \k -> spec2' ChernRoot  (calc k)
  cacheChern = monoCache $ \k -> spec2' ChernClass (calc k)

-- | Like 'polyCache1', but for results with two extra layers of structure
-- (@f (g (h base))@).
--
-- The two caches are specialised through @spec3'@ to 'ChernRoot' and
-- 'ChernClass', and @select3@ presumably picks the one matching the
-- requested @base@. NOTE(review): confirm in
-- "Math.RootLoci.Algebra.SymmPoly".
polyCache3 
  :: (CacheKey key) 
  => (forall base. ChernBase base => key -> f (g (h base)))     -- ^ polymorphic function to be cached
  -> (forall base. ChernBase base => key -> f (g (h base)))     -- ^ memoised version of the same function
polyCache3 calc = \key -> select3 (cacheAB key, cacheChern key)  where
  -- Bound over @calc@ only (the lambda binds @key@), so the caches are
  -- shared by all calls of the returned function.
  cacheAB    = monoCache $ \k -> spec3' ChernRoot  (calc k)
  cacheChern = monoCache $ \k -> spec3' ChernClass (calc k)

--------------------------------------------------------------------------------
-- * monomorphic caching 

-- | Key types that have a memoisation combinator backed by a lazily built
-- table (the instances below use 'icache', 'pcache' and 'setpcache').
class CacheKey key where
  -- | @monoCache f@ returns a function intended to agree with @f@, whose
  -- results are stored in a table shared by all calls of the returned
  -- function.
  monoCache :: (key -> a) -> (key -> a)

-- Each key type uses the table-based cache that matches its shape.

-- | Integers are cached in an infinite list.
instance CacheKey Int where
  monoCache = icache

-- | Integer partitions are cached in tables grouped by weight.
instance CacheKey Partition where
  monoCache = pcache

-- | Set partitions are cached in tables grouped by size.
instance CacheKey SetPartition where
  monoCache = setpcache

--------------------------------------------------------------------------------
-- * individual caching functions

-- | Memoise a function on integer partitions.
--
-- The series of tables is bound over @calc@ only, so a single 'PSeries' is
-- shared by every call of the returned function.
pcache :: (Partition -> a) -> (Partition -> a)
pcache calc = (`lookupPSeries` memo)
  where
    memo = createPSeries calc

-- | Memoise a function on set partitions.
--
-- The series of tables is bound over @calc@ only, so a single 'SetPSeries'
-- is shared by every call of the returned function.
setpcache :: (SetPartition -> a) -> (SetPartition -> a)
setpcache calc = (`lookupSetPSeries` memo)
  where
    memo = createSetPSeries calc

-- | Memoise a function on non-negative integers.
--
-- Results are stored in an infinite lazy list indexed from 0 and shared by
-- all calls of the returned function. A lookup of @n@ walks @n@ list cells,
-- and a negative argument makes '!!' fail.
icache :: (Int -> a) -> (Int -> a)
icache calc = (memo !!)
  where
    memo = map calc [0 ..]

-- | Memoise a function on integers starting from a given index.
--
-- Arguments below @fstidx@ return @dflt@ without consulting the cache.
-- Arguments at or above @fstidx@ are looked up in an infinite lazy list of
-- results that starts at @calc fstidx@ and is shared by all calls.
icache' :: a -> Int -> (Int -> a) -> (Int -> a)
icache' dflt fstidx calc = lookupFrom
  where
    lookupFrom n
      | n < fstidx = dflt
      | otherwise  = memo !! (n - fstidx)
    memo = map calc [fstidx ..]

--------------------------------------------------------------------------------