{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}

-- | Memoization of monadic actions, keyed by the 'StableName' of the
-- action itself rather than by any structural notion of equality.
module Data.SNMap
  ( SNMap
  , SNMapReaderT
  , runSNMapReaderT
  , newSNMap
  , memoize
  , memoizeM
  ) where

import Control.Applicative (Applicative)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Functor
import System.Mem.StableName
import System.Mem.Weak (addFinalizer)

import Control.Monad.Exception (MonadAsyncException, MonadException)
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader (ReaderT, ask, runReaderT)
import qualified Data.HashTable.IO as HT

-- | A mutable table mapping the 'StableName' of an action of type @m a@
-- to a value of type @a@.
newtype SNMap m a = SNMap (HT.BasicHashTable (StableName (m a)) a)

-- | Allocate a fresh, empty 'SNMap'.
newSNMap :: IO (SNMap m a)
newSNMap = SNMap <$> HT.new

-- | Run an action through the given table.
--
-- The action is forced (via '$!') before its 'StableName' is taken, and
-- that name is looked up in the table:
--
-- * on a hit, the stored value is returned and the action is not run;
--
-- * on a miss, the action is run, its result is stored under the stable
--   name, and a finalizer is attached to the action that deletes the
--   table entry again.
memoize :: MonadIO m => SNMap m a -> m a -> m a
memoize (SNMap table) action = do
  key <- liftIO (makeStableName $! action)
  cached <- liftIO (HT.lookup table key)
  maybe (runAndStore key) return cached
  where
    -- Cache miss: execute the action, then record its result and register
    -- the clean-up finalizer in a single 'liftIO' step.
    runAndStore key = do
      result <- action
      liftIO $ do
        HT.insert table key result
        addFinalizer action (HT.delete table key)
      return result

-- | A reader transformer whose environment is an 'SNMap' for actions in
-- this very monad that produce values of type @a@.
newtype SNMapReaderT a m b
  = SNMapReaderT (ReaderT (SNMap (SNMapReaderT a m) a) m b)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadException
    , MonadAsyncException
    )

-- | Run an 'SNMapReaderT' computation against a newly created, empty table.
runSNMapReaderT :: MonadIO m => SNMapReaderT a m b -> m b
runSNMapReaderT (SNMapReaderT body) = liftIO newSNMap >>= runReaderT body

-- | Lifts an action of the base monad by lifting it into the underlying
-- 'ReaderT' and rewrapping.
instance MonadTrans (SNMapReaderT a) where
  lift inner = SNMapReaderT (lift inner)

-- | 'memoize' an action using the table carried in the reader environment.
memoizeM :: MonadIO m => SNMapReaderT a m a -> SNMapReaderT a m a
memoizeM action = SNMapReaderT ask >>= \table -> memoize table action