{-# LANGUAGE CPP        #-}

module Agda.Utils.Memo where

import Control.Monad.State
import System.IO.Unsafe
import Data.IORef
import qualified Data.Map as Map
import qualified Agda.Utils.HashMap as HMap
import Data.Hashable

import Agda.Utils.Lens

-- Simple memoisation in a state monad

-- | Simple, non-reentrant memoisation.
--
--   Looks at the cache slot reached through @tbl@.  When it already holds
--   @Just x@, @x@ is returned without running @compute@.  Otherwise
--   @compute@ is run, its result is written back into the slot as @Just@,
--   and then returned.
--
--   Non-reentrant: while @compute@ is running the slot still holds
--   'Nothing', so a recursive call through the same lens runs @compute@
--   again.  Use 'memoRec' when recursion is expected.
memo :: MonadState s m => Lens' (Maybe a) s -> m a -> m a
memo tbl compute = use tbl >>= maybe fresh return
  where
    -- Cache miss: compute, store, then hand back the fresh value.
    fresh = do
      result <- compute
      tbl .= Just result
      return result

-- | Recursive memoisation.  The second argument is the value that
--   recursive calls see while the first computation is still running.
--
--   If the slot reached through @tbl@ already holds @Just x@, @x@ is
--   returned.  Otherwise the slot is seeded with @Just ih@, @compute@ is
--   run, and its result overwrites the placeholder and is returned.
memoRec :: MonadState s m => Lens' (Maybe a) s -> a -> m a -> m a
memoRec tbl ih compute = use tbl >>= maybe fresh return
  where
    fresh = do
      -- Placeholder so that re-entrant lookups terminate with @ih@.
      tbl .= Just ih
      result <- compute
      tbl .= Just result
      return result

{-# NOINLINE memoUnsafe #-}
-- | Memoise a pure function behind a private, mutable 'Map' table.
--
--   The table is allocated (via 'unsafePerformIO') when @memoUnsafe f@
--   itself is evaluated.  Only calls made through that same evaluated
--   closure share results, so bind @memoUnsafe f@ once (for example at top
--   level or in a single @let@) instead of re-applying 'memoUnsafe' at
--   every call site.
--
--   On a miss, the unevaluated thunk @f x@ is inserted into the table
--   before anyone forces it.  This lets @f@ call the memoised function on
--   other arguments (recursive memoisation).
--
--   Lookup-or-insert is a single 'atomicModifyIORef'.  Evaluations of the
--   memoised function from different threads therefore cannot overwrite
--   each other's table entries, and they all share one @f x@ thunk per key.
memoUnsafe :: Ord a => (a -> b) -> (a -> b)
memoUnsafe f = unsafePerformIO $ do
  tbl <- newIORef Map.empty
  return (unsafePerformIO . f' tbl)
  where
    -- Atomically look up @x@.  On a miss, store the lazy thunk @f x@ and
    -- return it.  The lazy (non-primed) variant is used deliberately.
    -- Like the previous 'writeIORef', it neither forces the new map nor
    -- forces @f x@ while updating the table.
    f' tbl x = atomicModifyIORef tbl $ \m ->
      case Map.lookup x m of
        Just y  -> (m, y)
        Nothing -> let y = f x in (Map.insert x y m, y)

{-# NOINLINE memoUnsafeH #-}
-- | Like 'memoUnsafe', but the private table is a hash map, so the key
--   only needs 'Eq' and 'Hashable' rather than 'Ord'.
--
--   The table is allocated (via 'unsafePerformIO') when @memoUnsafeH f@
--   itself is evaluated.  Only calls made through that same evaluated
--   closure share results, so bind @memoUnsafeH f@ once instead of
--   re-applying it at every call site.
--
--   On a miss, the unevaluated thunk @f x@ is inserted before anyone forces
--   it, so @f@ may call the memoised function on other arguments.
--
--   Lookup-or-insert is a single 'atomicModifyIORef'.  Evaluations from
--   different threads therefore cannot overwrite each other's entries, and
--   they all share one @f x@ thunk per key.
memoUnsafeH :: (Eq a, Hashable a) => (a -> b) -> (a -> b)
memoUnsafeH f = unsafePerformIO $ do
  tbl <- newIORef HMap.empty
  return (unsafePerformIO . f' tbl)
  where
    -- Atomically look up @x@.  On a miss, store the lazy thunk @f x@ and
    -- return it.  The lazy (non-primed) variant keeps the previous
    -- strictness: neither the new table nor @f x@ is forced here.
    f' tbl x = atomicModifyIORef tbl $ \m ->
      case HMap.lookup x m of
        Just y  -> (m, y)
        Nothing -> let y = f x in (HMap.insert x y m, y)