module Data.StableMemo.Internal (Ref (..), memo) where
import Data.Proxy
import System.Mem.StableName
import Data.HashTable.IO (BasicHashTable)
import System.IO.Unsafe (unsafePerformIO)
import System.Mem.Weak (Weak)
import qualified Data.HashTable.IO as HashTable
import qualified System.Mem.Weak as Weak
-- | Mutable hash table keyed by the 'StableName' of a value of type @a@,
-- i.e. by heap-object identity rather than by structural equality.
type SNMap a b = BasicHashTable (StableName a) b
-- | The cache behind a memoised function: maps the stable name of each
-- argument seen so far to a reference (of kind @ref@, see 'Ref') to the
-- result computed for it.
type MemoTable ref a b = SNMap a (ref b)
-- | How a memo table holds on to a cached result.
--
-- The table stores one @ref v@ per argument it has seen. The instance decides
-- how strongly that reference retains the result; for example, the 'Weak'
-- instance keys it on the argument.
class Ref ref where
  -- | @mkRef key val onFinalize@ creates a reference to @val@. The memo code
  -- passes the memoised argument as @key@, the (lazy) result as @val@, and an
  -- action that evicts the corresponding table entry as @onFinalize@.
  mkRef :: k -> v -> IO () -> IO (ref v)

  -- | Read the referenced value back. 'Nothing' is treated by the memo code
  -- as a cache miss, so the result is recomputed.
  deRef :: ref v -> IO (Maybe v)

  -- | Finalize the reference. The memo code calls this on every remaining
  -- entry when a whole table is finalized (see 'tableFinalizer').
  finalize :: ref v -> IO ()
-- | Weak references: the result is the weak pointer's value and the memoised
-- argument is its key. Once the key becomes unreachable, 'deRef' yields
-- 'Nothing' and the RTS may run the supplied finalizer.
instance Ref Weak where
  mkRef key val onFinalize = Weak.mkWeak key val (Just onFinalize)
  deRef = Weak.deRefWeak
  finalize = Weak.finalize
-- | Finalizer attached to a single cache entry. It removes the entry stored
-- under @sn@ from the memo table, but only while the table is still
-- reachable through its weak pointer. If the table's weak pointer is already
-- dead, there is nothing to evict and this is a no-op.
finalizer :: StableName a -> Weak (MemoTable ref a b) -> IO ()
finalizer sn weakTbl = Weak.deRefWeak weakTbl >>= maybe (return ()) evict
  where
    evict tbl = HashTable.delete tbl sn
-- | Apply @f@ to @x@, consulting and updating the memo table @tbl@.
--
-- The bang pattern forces @x@ to WHNF before its 'StableName' is taken,
-- because a stable name taken of an unevaluated thunk may differ from the one
-- taken after evaluation. A cached result is reused only if the table has an
-- entry for that stable name /and/ the entry's reference still dereferences
-- to a value. Otherwise @f x@ is bound lazily, stored under a fresh reference
-- whose finalizer evicts the entry (via 'finalizer' and @weakTbl@), and
-- returned.
--
-- The 'Proxy' argument is ignored at runtime; it only fixes the @ref@ type.
--
-- NOTE(review): 'BasicHashTable' is not thread-safe, and entry finalizers
-- call 'HashTable.delete' from whichever thread the RTS runs finalizers on.
-- Concurrent use of a memoised function is presumably unsafe; confirm before
-- relying on it from multiple threads.
memo'
  :: Ref ref
  => Proxy ref
  -> (a -> b)
  -> MemoTable ref a b
  -> Weak (MemoTable ref a b)
  -> (a -> b)
memo' _ f tbl weakTbl !x = unsafePerformIO $ do
  name <- makeStableName x
  -- A missing entry and a dead reference both collapse to 'Nothing'.
  cached <- HashTable.lookup tbl name >>= maybe (return Nothing) deRef
  maybe (store name) return cached
  where
    -- Cache miss: record a lazily computed result under @name@ and return it.
    store name = do
      let result = f x
      ref <- mkRef x result (finalizer name weakTbl)
      HashTable.insert tbl name ref
      return result
-- | Finalizer for a whole memo table: calls 'finalize' on every reference the
-- table still holds.
--
-- NOTE(review): when this runs as the table's own weak-pointer finalizer, the
-- table's weak pointer is presumably already dead. The per-entry finalizers
-- ('finalizer') should then be no-ops rather than deleting from the table
-- during this traversal; confirm this for any new 'Ref' instance.
tableFinalizer :: Ref ref => MemoTable ref a b -> IO ()
tableFinalizer = HashTable.mapM_ finalizeEntry
  where
    finalizeEntry (_, ref) = finalize ref
-- | @memo p f@ returns a memoised version of @f@. Results are cached by the
-- identity ('StableName') of each argument. The @ref@ type selected by the
-- 'Proxy' decides how cached results are retained (see 'Ref').
--
-- Each evaluation of @memo p f@ allocates its own memo table. The table is
-- tracked by a weak pointer that serves two purposes:
--
-- * its finalizer ('tableFinalizer') finalizes all remaining entries once the
--   table is unreachable;
-- * entry finalizers ('finalizer') dereference it to evict themselves while
--   the table is alive.
memo :: Ref ref => Proxy (ref :: * -> *) -> (a -> b) -> (a -> b)
memo p f = unsafePerformIO $ do
  -- The allocation lives in the same IO action that builds the memoised
  -- function (which mentions @f@). Full laziness therefore cannot float it
  -- out and share one table between unrelated memoised functions.
  tbl <- HashTable.new
  -- Key the weak pointer (and its finalizer) on the table value just
  -- allocated, not on a lazy reference to it, so its lifetime tracks the
  -- table itself.
  -- NOTE(review): the key is still a boxed table handle. GHC documents that
  -- weak pointers on boxed values can be finalized early if the box is
  -- unpacked or re-allocated; confirm against the hashtables representation.
  weakTbl <- Weak.mkWeakPtr tbl (Just (tableFinalizer tbl))
  return (memo' p f tbl weakTbl)
-- As recommended for any function that calls 'unsafePerformIO': stop call
-- sites from inlining, duplicating, or floating the table allocation.
{-# NOINLINE memo #-}