{- |
Module      :  Control.Monad.Memo.Class
Copyright   :  (c) Eduard Sergeev 2011
License     :  BSD-style (see the file LICENSE)

Maintainer  :  eduard.sergeev@gmail.com
Stability   :  experimental
Portability :  non-portable (multi-param classes, functional dependencies)

[Computation type:] Interface for monadic computations which can be memoized.

-}

{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}

module Control.Monad.Memo.Class
(
      MonadCache(..),
      MonadMemo(..),

      for2,
      for3,
      for4,

      memoln,
      memol0,
      memol1,
      memol2,
      memol3,
      memol4,

) where

-- NOTE: the module is compiled with NoImplicitPrelude, so the imports below
-- are what bring 'Monad', 'Maybe', 'Either', 'Monoid', 'id', '.', and '$'
-- into scope.
import Data.Function
import Data.Maybe
import Data.Either
import Data.Monoid

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Error
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.State.Lazy as Lazy -- (StateT, get, put)
import qualified Control.Monad.Trans.State.Strict as Strict -- (StateT, get, put)
import Control.Monad.Trans.Writer.Lazy as Lazy
import Control.Monad.Trans.Writer.Strict as Strict

import Control.Arrow

import Prelude (undefined)


-- | Interface of a monad which carries a cache from keys @k@ to values @v@.
-- Both the key and the value type are determined by the monad.
class Monad m => MonadCache k v m | m -> k, m -> v where
    -- | Query the cache for a key. 'memoln' treats 'Just' as a hit and
    -- 'Nothing' as a miss.
    lookup :: k -> m (Maybe v)
    -- | Record a value for a key. 'memoln' calls this after a miss.
    add :: k -> v -> m ()

-- | Interface of a monad in which a function @k -> m v@ can be memoized.
class Monad m => MonadMemo k v m | m -> k, m -> v where
    -- | Wrap a monadic function into its memoized version.
    memo :: (k -> m v) -> k -> m v


-- | General memoization combinator all other helpers are built from.
--
-- * @fl@ embeds an action of the cache-carrying monad @m1@ into the monad
--   @m2@ in which the memoized function runs.
--
-- * @fk@ converts the function argument into the key used by the cache.
--
-- * @f@ is the function being memoized and @k@ is its argument.
--
-- The cache is queried first; on a miss @f k@ is run and its result is
-- added to the cache before being returned.
memoln :: (MonadCache k2 v m1, Monad m1, Monad m2) =>
           (forall a.m1 a -> m2 a) -> (k1 -> k2)  -> (k1 -> m2 v) -> k1 -> m2 v
memoln fl fk f k = do
    cached <- fl (lookup key)
    case cached of
      Just v  -> return v
      Nothing -> do
        v <- f k
        fl (add key v)
        return v
  where
    key = fk k

-- | 'memoln' with the function argument used directly as the cache key.
-- Kept as an internal helper; defined via 'memoln' instead of repeating
-- its body.
memoln2 :: (MonadCache k v m1, Monad m1, Monad m2) =>
           (forall a.m1 a -> m2 a) -> (k -> m2 v) -> k -> m2 v
memoln2 fl = memoln fl id

-- | Memoization directly in the cache-carrying monad (no lifting and no key
-- conversion). Not exported and not used elsewhere in this module.
memov2 :: (MonadCache k v m) => (k -> m v) -> k -> m v
memov2 = memoln id id


-- | Adapter for memoization of two-argument function
for2 :: (((k1, k2) -> mv) -> (k1, k2) -> mv) -> (k1 -> k2 -> mv) -> k1 -> k2 -> mv
for2 m f a b = m (\(x, y) -> f x y) (a, b)

-- | Adapter for memoization of three-argument function
for3 :: (((k1, k2, k3) -> mv) -> (k1, k2, k3) -> mv) -> (k1 -> k2 -> k3 -> mv) -> k1 -> k2 -> k3 -> mv
for3 m f a b c = m (\(x, y, z) -> f x y z) (a, b, c)

-- | Adapter for memoization of four-argument function
for4 :: (((k1, k2, k3, k4) -> mv) -> (k1, k2, k3, k4) -> mv) -> (k1 -> k2 -> k3 -> k4 -> mv) -> k1 -> k2 -> k3 -> k4 -> mv
for4 m f a b c d = m (\(w, x, y, z) -> f w x y z) (a, b, c, d)


-- | Uses current monad's memoization cache
memol0 :: (MonadCache k v m, Monad m) => (k -> m v) -> k -> m v
memol0 = memoln id id

-- | Uses the 1st transformer in stack for memoization cache
-- (cache operations are lifted once)
memol1 :: (MonadTrans t1, MonadCache k v m, Monad (t1 m)) =>
          (k -> t1 m v) -> k -> t1 m v
memol1 = memoln lift id

-- | Uses the 2nd transformer in stack for memoization cache
-- (cache operations are lifted twice)
memol2 :: (MonadTrans t1, MonadTrans t2, MonadCache k v m,
           Monad (t2 m), Monad (t1 (t2 m))) =>
          (k -> t1 (t2 m) v) -> k -> t1 (t2 m) v
memol2 = memoln (lift . lift) id

-- | Uses the 3rd transformer in stack for memoization cache
-- (cache operations are lifted three times)
memol3 :: (MonadTrans t1, MonadTrans t2, MonadTrans t3, MonadCache k v m,
           Monad (t3 m), Monad (t2 (t3 m)), Monad (t1 (t2 (t3 m))) ) =>
          (k -> t1 (t2 (t3 m)) v) -> k -> t1 (t2 (t3 m)) v
memol3 = memoln (lift . lift . lift) id

-- | Uses the 4th transformer in stack for memoization cache
-- (cache operations are lifted four times)
memol4 :: (MonadTrans t1, MonadTrans t2, MonadTrans t3, MonadTrans t4, MonadCache k v m,
           Monad (t4 m), Monad (t3 (t4 m)), Monad (t2 (t3 (t4 m))), Monad (t1 (t2 (t3 (t4 m)))) ) =>
          (k -> t1 (t2 (t3 (t4 m))) v) -> k -> t1 (t2 (t3 (t4 m))) v
memol4 = memoln (lift . lift . lift . lift) id


-- | The function is unwrapped to the underlying monad and memoized there.
instance (MonadCache k v m) => MonadMemo k v (IdentityT m) where
    memo f = IdentityT . memol0 (runIdentityT . f)

-- | Cache operations are lifted from the underlying monad.
instance (MonadCache k v m) => MonadMemo k v (ContT r m) where
    memo = memol1

-- | The cache stores the whole @Maybe v@ produced by 'runMaybeT', so a
-- 'Nothing' outcome is cached as well.
instance (MonadCache k (Maybe v) m) => MonadMemo k v (MaybeT m) where
    memo f = MaybeT . memol0 (runMaybeT . f)

-- | The cache stores the whole list of results produced by 'runListT'.
instance (MonadCache k [v] m) => MonadMemo k v (ListT m) where
    memo f = ListT . memol0 (runListT . f)

-- | The cache stores the whole @Either e v@ produced by 'runErrorT', so an
-- error outcome is cached as well.
instance (Error e, MonadCache k (Either e v) m) => MonadMemo k v (ErrorT e m) where
    memo f = ErrorT . memol0 (runErrorT . f)

-- | The current environment is made part of the cache key.
instance (MonadCache (r,k) v m) => MonadMemo k v (ReaderT r m) where
    memo f k = do
      env <- ask
      memoln lift (env,) f k

-- | The cache stores the result together with the writer output.
instance (Monoid w, MonadCache k (v,w) m) => MonadMemo k v (Lazy.WriterT w m) where
    memo f = Lazy.WriterT . memol0 (Lazy.runWriterT . f)

-- | The cache stores the result together with the writer output.
instance (Monoid w, MonadCache k (v,w) m) => MonadMemo k v (Strict.WriterT w m) where
    memo f = Strict.WriterT . memol0 (Strict.runWriterT . f)

-- | The state observed on entry is made part of the cache key.
--
-- NOTE: only the result value is cached. On a cache hit the function is not
-- run, so any state changes it would have made are not replayed; the state
-- is left as it was on entry.
instance (MonadCache (s,k) v m) => MonadMemo k v (Lazy.StateT s m) where
    memo f k = do
      st <- Lazy.get
      memoln lift (st,) f k

-- | The state observed on entry is made part of the cache key.
--
-- NOTE: only the result value is cached. On a cache hit the function is not
-- run, so any state changes it would have made are not replayed; the state
-- is left as it was on entry.
instance (MonadCache (s,k) v m) => MonadMemo k v (Strict.StateT s m) where
    memo f k = do
      st <- Strict.get
      memoln lift (st,) f k