{-# LANGUAGE BangPatterns #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE Rank2Types #-} {-# LANGUAGE TypeOperators #-} module Data.StableMemo.Internal (Ref (..), Strong (..), (-->) (), memo) where import Data.Proxy import System.Mem.StableName import Data.HashTable.IO (BasicHashTable) import GHC.Prim (Any) import System.IO.Unsafe (unsafePerformIO) import System.Mem.Weak (Weak) import Unsafe.Coerce (unsafeCoerce) import qualified Data.HashTable.IO as HashTable import qualified System.Mem.Weak as Weak newtype (f <<< g) a = O { unO :: f (g a) } -- Invariant: The type parameters for a key and its corresponding -- value are the same. type SNMap f g = BasicHashTable (StableName (f Any)) (g Any) type MemoTable ref f g = SNMap f (ref <<< g) class Ref ref where mkRef :: a -> b -> IO () -> IO (ref b) deRef :: ref a -> IO (Maybe a) finalize :: ref a -> IO () instance Ref Weak where mkRef x y = Weak.mkWeak x y . Just deRef = Weak.deRefWeak finalize = Weak.finalize data Strong a = Strong a !(Weak a) instance Ref Strong where mkRef _ y final = do weak <- Weak.mkWeakPtr y $ Just final return $ Strong y weak deRef (Strong x _) = return $ Just x finalize (Strong _ weak) = Weak.finalize weak finalizer :: StableName (f Any) -> Weak (MemoTable ref f g) -> IO () finalizer sn weakTbl = do r <- Weak.deRefWeak weakTbl case r of Nothing -> return () Just tbl -> HashTable.delete tbl sn unsafeToAny :: f a -> f Any unsafeToAny = unsafeCoerce unsafeFromAny :: f Any -> f a unsafeFromAny = unsafeCoerce -- | Polymorphic memoizable function type f --> g = forall a. f a -> g a memo' :: Ref ref => Proxy ref -> (f --> g) -> MemoTable ref f g -> Weak (MemoTable ref f g) -> (f --> g) memo' _ f tbl weakTbl !x = unsafePerformIO $ do sn <- makeStableName $ unsafeToAny x lkp <- HashTable.lookup tbl sn case lkp of Nothing -> notFound sn Just (O w) -> do maybeVal <- deRef w case maybeVal of Nothing -> notFound sn Just val -> return $ unsafeFromAny val where notFound sn = do let y = f x weak <- mkRef x (unsafeToAny y) $ finalizer sn weakTbl HashTable.insert tbl sn $ O weak return y tableFinalizer :: Ref ref => MemoTable ref f g -> IO () tableFinalizer = HashTable.mapM_ $ finalize . unO . snd memo :: Ref ref => Proxy (ref :: * -> *) -> (f --> g) -> (f --> g) memo p f = let (tbl, weak) = unsafePerformIO $ do tbl' <- HashTable.new weak' <- Weak.mkWeakPtr tbl . Just $ tableFinalizer tbl return (tbl', weak') in memo' p f tbl weak