{-# LANGUAGE PackageImports
            ,FlexibleInstances
            ,MultiParamTypeClasses
            ,UndecidableInstances
            ,TypeSynonymInstances
            #-}
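-- | Memoisation combinators for recursive functions.
-- Cached results live in a memo table (any 'MemoTable' instance) that is
-- threaded through the 'State' monad; useful for dynamic programming.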
module Data.Function.YaMemo (
  -- * Class
    MemoTable(..)
  -- * Type
  , Memo
  -- * Function
  , memo) where

import Data.Function (fix)

import "mtl" Control.Monad.State
import qualified Data.Map as M

-- | A table mapping arguments (of an 'Ord' type) to cached results; this is
-- the state carried by a 'Memo' computation.
class MemoTable t where
  -- | The table 'evalMemo' starts every evaluation from.
  emptyMemoTable  :: Ord a => t a b
  -- | Look up a key; 'memoise' treats a 'Just' result as a cache hit.
  lookupMemoTable :: Ord a => a -> t a b -> Maybe b
  -- | Record a result under a key, returning the updated table; 'memoise'
  -- calls this after computing a value on a cache miss.
  insertMemoTable :: Ord a => a -> b -> t a b -> t a b

-- | Variant of a memo table whose stored values are monadic computations
-- @m b@ rather than plain results.
--
-- NOTE(review): this class is not in the module's export list and nothing
-- in this module uses it -- confirm whether it is still needed.
class (Monad m) => MemoTableT t m where
  -- | Starting table of monadic values.
  emptyMemoTableT  :: Ord a => t a (m b)
  -- | Look up the stored computation for a key ('Nothing' when absent).
  lookupMemoTableT :: Ord a => a -> t a (m b) -> Maybe (m b)
  -- | Store a computation under a key, returning the updated table.
  insertMemoTableT :: Ord a => a -> m b -> t a (m b) -> t a (m b)

-- | Memo table backed by 'M.Map': every method forwards directly to the
-- matching "Data.Map" operation.
instance MemoTable M.Map where
  lookupMemoTable key       = M.lookup key
  insertMemoTable key value = M.insert key value
  emptyMemoTable            = M.empty

-- | 'M.Map'-backed table whose stored values are list computations; every
-- method forwards directly to the matching "Data.Map" operation.
instance MemoTableT M.Map [] where
  lookupMemoTableT key       = M.lookup key
  insertMemoTableT key value = M.insert key value
  emptyMemoTableT            = M.empty

{-
instance (MemoTable t, Ord a, Num b, Eq b) => Eq (State (t a b) b) where
  sx == sy = evalState sx emptyMemoTable == evalState sy emptyMemoTable

instance (MemoTable t, Ord a, Num b, Show b) => Show (State (t a b) b) where
  show sx = show (evalState sx emptyMemoTable)

instance (MemoTable t, Ord a, Num b) => Num (State (t a b) b) where
  (+)    = liftM2 (+)
  (-)    = liftM2 (-)
  (*)    = liftM2 (*)
  negate = liftM negate
  abs    = liftM abs
  signum = liftM signum
  fromInteger = return . fromInteger

instance (MemoTable t, Ord a, Monoid b) => Monoid (State (t a b) b) where
  mempty   = return mempty
  mappend  = liftM2 mappend
-}

-- | A memoisable function: maps an argument to a 'State' computation that
-- threads a memo table of type @t a b@ and yields the result.
type Memo t a b = a -> State (t a b) b

-- | Put a memo-table check in front of one step of a computation.
--
-- If the table already holds a result for @x@, that result is returned
-- and @mf@ is not run. Otherwise @mf x@ runs, its result is inserted into
-- the table under @x@, and that result is returned.
memoise :: (MemoTable t, Ord a) => Memo t a b -> Memo t a b
memoise mf x = do
  cached <- gets (lookupMemoTable x)
  case cached of
    Just y  -> return y
    Nothing -> do
      y <- mf x
      modify (insertMemoTable x y)
      return y

-- | Run a memoised computation on a single argument and return its result.
--
-- Each call starts from 'emptyMemoTable' and discards the final table.
-- Cached results are therefore shared only within one top-level
-- evaluation, not across separate calls.
evalMemo :: (MemoTable t, Ord a) => Memo t a b -> a -> b
evalMemo m v = evalState (m v) emptyMemoTable

-- | Tie a recursive knot through a pair of functions: @gfun f g@ is the
-- fixed point of @f . g@, i.e. @gfun f g = fix (f . g)@.
--
-- 'memoising' applies it with @f = memoise@, so the function handed back
-- to the user's step function is itself the memoised one.
gfun :: (b -> c) -> (c -> b) -> c
gfun f g = fix (f . g)

-- | Close an open-recursive step function into a memoised 'Memo'.
--
-- @memoising f@ is the fixed point of @memoise . f@. Every recursive call
-- @f@ makes through the function it receives therefore consults and
-- updates the memo table.
memoising :: (Ord a, MemoTable t)
          => (Memo t a b -> Memo t a b) -> Memo t a b
memoising = gfun memoise

-- | Turn an open-recursive step function into an ordinary function with
-- memoised recursion.
--
-- The first argument is only a type witness that selects the memo table
-- type @t@. It is passed as the second argument of 'asTypeOf', so 'memo'
-- never forces it, and an annotated @undefined@ is enough. The second
-- argument receives the memoised function for its recursive calls.
-- Each application of the result starts from an empty table (see
-- 'evalMemo').
--
-- Example, assuming @import qualified Data.Map as M@:
--
-- > fib :: Integer -> Integer
-- > fib = memo (undefined :: Memo M.Map Integer Integer) step
-- >   where step self n | n < 2     = return n
-- >                     | otherwise = (+) <$> self (n - 1) <*> self (n - 2)
memo :: (MemoTable t, Ord a)
     => (a -> State (t a b) b)
     -> ((a -> State (t a b) b) -> Memo t a b)
     -> (a -> b)
memo witness step = evalMemo knotted
  where
    -- 'asTypeOf' returns its first argument; 'witness' only fixes the type.
    knotted = memoising step `asTypeOf` witness