{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Control.Search.MemoReader where

import Control.Search.Memo

import Data.Map (Map)
import qualified Data.Map as Map

import Control.Monatron.Monatron hiding (Abort, L, state, cont)
import Control.Monatron.Zipper hiding (i,r)
import Control.Monatron.MonadInfo
import Control.Monatron.IdT

newtype MemoReaderT r m a = MemoReaderT { unMemoReaderT :: Int -> ReaderT r m a }

instance MonadT (MemoReaderT r) where
  lift m = MemoReaderT $ const $ lift m
  tbind (MemoReaderT i) f = MemoReaderT (\n -> i n `tbind` (\r -> unMemoReaderT (f r) n))

instance MonadInfoT (MemoReaderT r) where
  tminfo x = miInc "MemoReaderT" (minfo $ runReaderT undefined (unMemoReaderT x 0))

instance FMonadT (MemoReaderT s) where
  tmap' d1 d2 g f (MemoReaderT m) = MemoReaderT (tmap' d1 d2 g f . m)

memoReaderT :: MemoM m => (e -> Int -> m a) -> MemoReaderT e m a
memoReaderT f = MemoReaderT (\n -> readerT (\e -> f e n))

deMemoReaderT :: MemoM m => e -> Int -> MemoReaderT e m a -> m a
deMemoReaderT e i (MemoReaderT f) = runReaderT e (f i)

runMemoReaderT :: (MemoM m, Show s) => s -> MemoReaderT s m a -> m a
runMemoReaderT s r = 
  do x1 <- getMemo
     let l = Map.size (memoRead x1)
     setMemo x1 { memoRead = Map.insert l (show s) $ memoRead x1 }
     r <- deMemoReaderT s l r
     x2 <- getMemo
     setMemo x2 { memoRead = Map.delete l $ memoRead x2 }
     return r

modelMemoReaderT :: (Show s, MemoM m) => Model (ReaderOp s) (MemoReaderT s m)
modelMemoReaderT (Ask g)     = memoReaderT (\s n -> deMemoReaderT s n (g s))
modelMemoReaderT (InEnv s a) = memoReaderT (\_ n -> deMemoReaderT s n (do { m1 <- getMemo
                                                                          ; let oldVal = memoRead m1 Map.! n
                                                                          ; setMemo m1 { memoRead = Map.insert n (show s) (memoRead m1) }
                                                                          ; x <- a
                                                                          ; m2 <- getMemo
                                                                          ; setMemo m2 { memoRead = Map.insert n oldVal (memoRead m2) }
                                                                          ; return x
                                                                          }
                                                                      )
                                           )

instance (MemoM m, Show s) => ReaderM s (MemoReaderT s m) where
  readerModel = modelMemoReaderT