{-# LANGUAGE CPP #-}
module Agda.Utils.Memo where
import Control.Monad.State
import System.IO.Unsafe
import Data.IORef
import qualified Data.Map as Map
import qualified Agda.Utils.HashMap as HMap
import Data.Hashable
import Agda.Utils.Lens
-- | Memoise a monadic computation in a @Maybe@ slot of the state.
--
--   The slot is read through the lens @tbl@.  If it already holds
--   @Just x@, that @x@ is returned and @compute@ is not run.  Otherwise
--   @compute@ is run, its result is written back to the slot and then
--   returned.
--
--   The slot is only filled after @compute@ has finished, so a lookup
--   of the same slot from inside @compute@ still sees @Nothing@.
memo :: MonadState s m => Lens' (Maybe a) s -> m a -> m a
memo tbl compute = use tbl >>= maybe computeAndStore return
  where
    -- Cache miss: run the computation, remember its result, hand it back.
    computeAndStore = do
      result <- compute
      tbl .= Just result
      return result
-- | Memoisation with a provisional value for nested lookups.
--
--   Behaves like a plain cache lookup when the slot behind @tbl@ is
--   already filled.  On a miss, the slot is first set to @Just ih@, so
--   any lookup of the same slot performed while @compute@ is running
--   gets @ih@ back.  Once @compute@ finishes, its result overwrites
--   @ih@ in the slot and is returned.
memoRec :: MonadState s m => Lens' (Maybe a) s -> a -> m a -> m a
memoRec tbl ih compute = use tbl >>= maybe computeWithHypothesis return
  where
    computeWithHypothesis = do
      -- Provisional answer, visible to lookups made during 'compute'.
      tbl .= Just ih
      result <- compute
      -- Replace the provisional answer with the real one.
      tbl .= Just result
      return result
-- | Memoise a pure function behind a hidden, mutable 'Map.Map'.
--
--   Applying @memoUnsafe@ to @f@ allocates one 'IORef' holding an empty
--   map.  The returned function looks its argument up in that map:
--   a hit returns the stored value; a miss inserts @f x@ under @x@
--   and returns it.
--
--   Both the allocation and every lookup go through 'unsafePerformIO'.
--   The table is read and written with two separate 'IORef' operations,
--   and a miss writes back the map that was read, extended by one entry.
--
--   NOTE(review): the NOINLINE pragma is presumably there to keep the
--   optimiser from duplicating the hidden table; confirm before
--   touching it.
{-# NOINLINE memoUnsafe #-}
memoUnsafe :: Ord a => (a -> b) -> (a -> b)
memoUnsafe f = unsafePerformIO $ do
  cacheRef <- newIORef Map.empty
  return (\key -> unsafePerformIO (fetch cacheRef key))
  where
    -- Look the key up in the current table, falling back to 'remember'.
    fetch cacheRef key = do
      cache <- readIORef cacheRef
      maybe (remember cacheRef cache key) return (Map.lookup key cache)

    -- Cache miss: record @f key@ in the table and return it.
    remember cacheRef cache key = do
      let result = f key
      writeIORef cacheRef (Map.insert key result cache)
      return result
-- | Variant of the unsafe memoiser that keeps its table in a hash map
--   (the @HMap@ module) instead of an ordered map, so keys need 'Eq'
--   and 'Hashable' rather than 'Ord'.
--
--   Applying @memoUnsafeH@ to @f@ allocates one 'IORef' holding an
--   empty hash map.  The returned function looks its argument up in
--   that map: a hit returns the stored value; a miss inserts @f x@
--   under @x@ and returns it.
--
--   Both the allocation and every lookup go through 'unsafePerformIO'.
--   The table is read and written with two separate 'IORef' operations.
--
--   NOTE(review): the NOINLINE pragma is presumably there to keep the
--   optimiser from duplicating the hidden table; confirm before
--   touching it.
{-# NOINLINE memoUnsafeH #-}
memoUnsafeH :: (Eq a, Hashable a) => (a -> b) -> (a -> b)
memoUnsafeH f = unsafePerformIO $ do
  cacheRef <- newIORef HMap.empty
  return (\key -> unsafePerformIO (fetch cacheRef key))
  where
    -- Look the key up in the current table, falling back to 'remember'.
    fetch cacheRef key = do
      cache <- readIORef cacheRef
      maybe (remember cacheRef cache key) return (HMap.lookup key cache)

    -- Cache miss: record @f key@ in the table and return it.
    remember cacheRef cache key = do
      let result = f key
      writeIORef cacheRef (HMap.insert key result cache)
      return result