{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}

module Control.Monad.MultiWrap (-- * MultiWrap class
                                MultiWrap(..)
                               -- * Example
                               -- $example
                               )where

import Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Control.Monad.Wrap

-- | 'MultiWrap' is like 'MonadWrapIO', but for monads created by
-- 'MonadTrans' transformers.  This is useful if, for example, you
-- implement your own monad, @ReaderTLike@, that is like a @ReaderT@
-- except that you don't want to make it a member of the @MonadReader@
-- class because you are already using @MonadReader@ for some
-- different state (or because you are implementing a library and the
-- user of the library should be free to use @ReaderT@).
--
-- As long as @ReaderTLike@ is a member of the 'MonadTrans' class and
-- assuming you have a @localLike@ function equivalent to 'local', you
-- should be able to run things like:
--
-- >    mwrap (localLike modifyConfig :: ReaderTLike IO a -> ReaderTLike IO a)
-- >          someComputation
--
-- You will generally have to specify the type of the wrap function or
-- computation explicitly, but as long as you specify the type,
-- 'mwrap' saves you from keeping track of how many nested levels of
-- transformer you have and from having to invoke 'wrap' repeatedly.
--
-- Note that one difference from 'MonadWrap' and 'MonadWrapIO' is that
-- 'mresult' and 'mresultF' require an extra argument so as to specify
-- the inner monad in which you want to supply the result.  (E.g., in
-- the case of using 'catch' to produce a different return value in
-- case of exceptions, the inner monad would be 'IO', and the extra
-- argument might be supplied as @(undefined :: 'IO' Type)@.)
--
-- Note that 'MultiWrap' only works for up to 9 levels of nested
-- monad transformers.
class (Monad mOut) => MultiWrap mIn mOut a r | mIn mOut a -> r where
    -- | Apply a wrapping function written for the inner monad @mIn@
    -- to a computation running in the outer monad @mOut@.
    mwrap :: (mIn r -> mIn r) -> mOut a -> mOut a
    -- | Produce, in the outer monad, a function that converts a value
    -- of type @a@ into a value of type @r@.
    mresultF :: mIn b
                -- ^ This argument is here just for the type, because
                -- otherwise the method has no way of knowing which
                -- inner monad you want.  The value of this argument
                -- is ignored, so it is safe to use
                -- @(undefined :: InnerMonad ())@ just as a way of
                -- specifying the type.
             -> mOut (a -> r)
    -- | Like 'mresultF', but immediately applies the produced
    -- function to the supplied value.  The first argument is passed
    -- straight through to 'mresultF'.
    mresult :: mIn b -> a -> mOut r
    mresult proxy x = do
        convert <- mresultF proxy
        return (convert x)

-- Base case: the inner and the outer monad are the same type, so there
-- is no transformer layer to cross.
instance (Monad m) => MultiWrap m m a a where
    -- The wrapper already has the type of the computation; apply it as is.
    mwrap wrapper action = wrapper action
    -- With @a@ and @r@ equal, the conversion function is the identity.
    mresultF _ = return id

{- $example

> module Main where
> 
> import Control.Monad.MultiLift
> import Control.Monad.MultiWrap
> import Control.Monad.Reader
> import Control.Monad.State
> import Control.Monad.Trans
> import Control.Monad.Wrap
> 
> newtype Type1 = Type1 { unType1 :: String }
> type Reader1 = ReaderT Type1 IO
> 
> newtype Type2 = Type2 { unType2 :: String }
> type Reader2 = ReaderT Type2 Reader1
> 
> type Outer = StateT () Reader2
> 
> r3 :: Outer ()
> r3 = do
>   -- Note that you have to specify the inner type
>   s1 <- mlift (asks unType1 :: Reader1 String)
>   liftIO $ putStrLn $ "s1: " ++ s1
>   s2 <- mlift (asks unType2 :: Reader2 String)
>   liftIO $ putStrLn $ "s2: " ++ s2
> 
> r2 :: Outer ()
> r2 = do
>   mwrap (local augment :: Reader1 a -> Reader1 a) r3
>   where
>     augment (Type1 s) = Type1 $ s ++ " (augmented)"
> 
> r1 :: Reader2 ()
> r1 = do
>   liftM fst $ runStateT r3 ()
>   liftM fst $ runStateT r2 ()
>   -- runContWrapT r2 return
>   
> 
> main :: IO ()
> main = do
>   runReaderT (runReaderT r1 $ Type2 "this is the Reader2 contents")
>        $ Type1 "this is the Reader1 contents"

-}
-- One transformer layer: the outer monad is @t1 m@ and the inner monad
-- is @m@.
instance (MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0, Monad m) => MultiWrap m (t1 m) a1 a0 where
    -- A single 'wrap' crosses the only layer.
    mwrap = wrap
    -- The layer's own 'resultF' already has the required type
    -- @t1 m (a1 -> a0)@; binding it and re-wrapping it with 'return'
    -- would be a no-op by the monad right-identity law.
    mresultF _ = resultF

-- Two transformer layers: the outer monad is @t2 (t1 m)@ and the inner
-- monad is @m@.
instance ( MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t2 (t1 m)) a2 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap f)
    -- Fetch each layer's 'resultF' in the outer monad (lifting the
    -- inner layer's one) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- lift resultF
        g2 <- resultF
        return (g1 . g2)

-- Three transformer layers: the outer monad is @t3 (t2 (t1 m))@ and the
-- inner monad is @m@.
instance ( MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t3 (t2 (t1 m))) a3 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap f))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 3 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift) resultF
        g2 <- lift resultF
        g3 <- resultF
        return (g1 . g2 . g3)

-- Four transformer layers: the outer monad is @t4 (t3 (t2 (t1 m)))@ and
-- the inner monad is @m@.
instance ( MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t4 (t3 (t2 (t1 m)))) a4 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap f)))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 4 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift) resultF
        g2 <- (lift . lift) resultF
        g3 <- lift resultF
        g4 <- resultF
        return (g1 . g2 . g3 . g4)

-- Five transformer layers: the outer monad is
-- @t5 (t4 (t3 (t2 (t1 m))))@ and the inner monad is @m@.
instance ( MonadTrans t5, Monad (t5 (t4 (t3 (t2 (t1 m))))), MonadWrap t5 a5 a4
         , MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t5 (t4 (t3 (t2 (t1 m))))) a5 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap (wrap f))))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 5 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift . lift) resultF
        g2 <- (lift . lift . lift) resultF
        g3 <- (lift . lift) resultF
        g4 <- lift resultF
        g5 <- resultF
        return (g1 . g2 . g3 . g4 . g5)

-- Six transformer layers: the outer monad is
-- @t6 (t5 (t4 (t3 (t2 (t1 m)))))@ and the inner monad is @m@.
instance ( MonadTrans t6, Monad (t6 (t5 (t4 (t3 (t2 (t1 m)))))), MonadWrap t6 a6 a5
         , MonadTrans t5, Monad (t5 (t4 (t3 (t2 (t1 m))))), MonadWrap t5 a5 a4
         , MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t6 (t5 (t4 (t3 (t2 (t1 m)))))) a6 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap (wrap (wrap f)))))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 6 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift . lift . lift) resultF
        g2 <- (lift . lift . lift . lift) resultF
        g3 <- (lift . lift . lift) resultF
        g4 <- (lift . lift) resultF
        g5 <- lift resultF
        g6 <- resultF
        return (g1 . g2 . g3 . g4 . g5 . g6)

-- Seven transformer layers: the outer monad is
-- @t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))@ and the inner monad is @m@.
instance ( MonadTrans t7, Monad (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))), MonadWrap t7 a7 a6
         , MonadTrans t6, Monad (t6 (t5 (t4 (t3 (t2 (t1 m)))))), MonadWrap t6 a6 a5
         , MonadTrans t5, Monad (t5 (t4 (t3 (t2 (t1 m))))), MonadWrap t5 a5 a4
         , MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))) a7 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap (wrap (wrap (wrap f))))))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 7 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift . lift . lift . lift) resultF
        g2 <- (lift . lift . lift . lift . lift) resultF
        g3 <- (lift . lift . lift . lift) resultF
        g4 <- (lift . lift . lift) resultF
        g5 <- (lift . lift) resultF
        g6 <- lift resultF
        g7 <- resultF
        return (g1 . g2 . g3 . g4 . g5 . g6 . g7)

-- Eight transformer layers: the outer monad is
-- @t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m)))))))@ and the inner monad is @m@.
instance ( MonadTrans t8, Monad (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m)))))))), MonadWrap t8 a8 a7
         , MonadTrans t7, Monad (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))), MonadWrap t7 a7 a6
         , MonadTrans t6, Monad (t6 (t5 (t4 (t3 (t2 (t1 m)))))), MonadWrap t6 a6 a5
         , MonadTrans t5, Monad (t5 (t4 (t3 (t2 (t1 m))))), MonadWrap t5 a5 a4
         , MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m)))))))) a8 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap (wrap (wrap (wrap (wrap f)))))))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 8 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift . lift . lift . lift . lift) resultF
        g2 <- (lift . lift . lift . lift . lift . lift) resultF
        g3 <- (lift . lift . lift . lift . lift) resultF
        g4 <- (lift . lift . lift . lift) resultF
        g5 <- (lift . lift . lift) resultF
        g6 <- (lift . lift) resultF
        g7 <- lift resultF
        g8 <- resultF
        return (g1 . g2 . g3 . g4 . g5 . g6 . g7 . g8)

-- Nine transformer layers (the deepest stack supported): the outer
-- monad is @t9 (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))))@ and the inner
-- monad is @m@.
instance ( MonadTrans t9, Monad (t9 (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))))), MonadWrap t9 a9 a8
         , MonadTrans t8, Monad (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m)))))))), MonadWrap t8 a8 a7
         , MonadTrans t7, Monad (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))), MonadWrap t7 a7 a6
         , MonadTrans t6, Monad (t6 (t5 (t4 (t3 (t2 (t1 m)))))), MonadWrap t6 a6 a5
         , MonadTrans t5, Monad (t5 (t4 (t3 (t2 (t1 m))))), MonadWrap t5 a5 a4
         , MonadTrans t4, Monad (t4 (t3 (t2 (t1 m)))), MonadWrap t4 a4 a3
         , MonadTrans t3, Monad (t3 (t2 (t1 m))), MonadWrap t3 a3 a2
         , MonadTrans t2, Monad (t2 (t1 m)), MonadWrap t2 a2 a1
         , MonadTrans t1, Monad (t1 m), MonadWrap t1 a1 a0
         , Monad m
         ) => MultiWrap m (t9 (t8 (t7 (t6 (t5 (t4 (t3 (t2 (t1 m))))))))) a9 a0 where
    -- One 'wrap' per layer; the innermost application handles layer t1.
    mwrap f = wrap (wrap (wrap (wrap (wrap (wrap (wrap (wrap (wrap f))))))))
    -- Fetch each layer's 'resultF' in the outer monad (layer i needs
    -- 9 - i lifts) and compose them so that the function of the
    -- outermost layer is applied to the value first.
    mresultF _ = do
        g1 <- (lift . lift . lift . lift . lift . lift . lift . lift) resultF
        g2 <- (lift . lift . lift . lift . lift . lift . lift) resultF
        g3 <- (lift . lift . lift . lift . lift . lift) resultF
        g4 <- (lift . lift . lift . lift . lift) resultF
        g5 <- (lift . lift . lift . lift) resultF
        g6 <- (lift . lift . lift) resultF
        g7 <- (lift . lift) resultF
        g8 <- lift resultF
        g9 <- resultF
        return (g1 . g2 . g3 . g4 . g5 . g6 . g7 . g8 . g9)