module Data.Comp.Param.Ditraversable
    (
     Ditraversable(..)
    ) where
import Prelude hiding (mapM, sequence, foldr)
import Data.Maybe (fromJust)
import Data.Comp.Param.Any
import Data.Comp.Param.Difunctor
import Test.QuickCheck.Gen
import Data.Functor.Identity
import Control.Monad.Reader hiding (mapM, sequence)
import Control.Monad.Error hiding (mapM, sequence)
import Control.Monad.State hiding (mapM, sequence)
import Control.Monad.List hiding (mapM, sequence)
import Control.Monad.RWS hiding (Any, mapM, sequence)
import Control.Monad.Writer hiding (Any, mapM, sequence)
-- | Monadic traversal over the second type argument of a 'Difunctor'.
--
-- The first type argument @a@ is a class parameter, so instances can be
-- restricted to one particular argument type.
--
-- The default definitions of 'dimapM' and 'disequence' are written in terms
-- of each other. An instance must therefore define at least one of them,
-- otherwise both loop.
class (Difunctor f, Monad m) => Ditraversable f m a where
    -- | Apply a monadic action to the second argument and collect its
    -- effects.
    dimapM :: (b -> m c) -> f a b -> m (f a c)
    -- | Collect the effects stored in the second argument.
    disequence :: f a (m b) -> m (f a b)

    dimapM f s = disequence (fmap f s)
    disequence s = dimapM id s
-- | Sequencing for QuickCheck generators.
--
-- Each argument's generator is run with the same random seed and size that
-- the outer generator receives.
instance Ditraversable (->) Gen a where
    dimapM f s = disequence (f . s)
    disequence s = MkGen (\gen size x -> unGen (s x) gen size)
-- | Sequencing for 'Identity': strip the wrapper from each result.
instance Ditraversable (->) Identity a where
    dimapM f s = disequence (f . s)
    disequence s = Identity (runIdentity . s)
-- | Sequencing for 'ReaderT'.
--
-- The environment is supplied to every argument's reader computation. The
-- resulting family of @m@ computations is then sequenced with the underlying
-- monad's instance.
instance Ditraversable (->) m a =>  Ditraversable (->) (ReaderT r m) a where
    dimapM f s = disequence (f . s)
    disequence s = ReaderT (\env -> disequence (\x -> runReaderT (s x) env))
-- | Sequencing for 'Maybe' over the argument type 'Any'.
--
-- The outcome is decided by probing the function once at 'undefined':
--
-- * If the probe gives 'Nothing', the whole result is 'Nothing'.
-- * Otherwise the result is a function that re-applies the original at each
--   argument and unwraps the 'Just'.
--
-- NOTE(review): this presumes that whether the function returns 'Nothing' or
-- 'Just' does not depend on its argument. 'Any' being opaque presumably
-- guarantees this; confirm. If the assumption is broken, the returned
-- function fails with a descriptive error rather than the bare 'fromJust'
-- failure.
instance Ditraversable (->) Maybe Any where
    dimapM f g = disequence (f . g)
    disequence f = case f undefined of
                     Nothing -> Nothing
                     Just _  -> Just (unJust . f)
        where
          -- Explicit unwrap with a message that identifies the failure,
          -- replacing the partial and uninformative 'fromJust'.
          unJust (Just y) = y
          unJust Nothing  = error "fromJust: expected Just"
-- | Sequencing for 'Either' over the argument type 'Any'.
--
-- The function is probed once at 'undefined'. A 'Left' from the probe is
-- propagated. Otherwise the result is a function that re-applies the
-- original at each argument and unwraps the 'Right'.
--
-- NOTE(review): this presumes the 'Left'/'Right' outcome does not depend on
-- the argument. Confirm.
instance Ditraversable (->) (Either e) Any where
    dimapM f g = disequence (f . g)
    -- 'liftM' on 'Either' keeps a 'Left' unchanged and replaces the payload
    -- of a 'Right' with the reconstructed function.
    disequence h = liftM (const (unRight . h)) (h undefined)
        where
          unRight (Right x) = x
          unRight (Left _)  = error "fromRight: expected Right"
-- | Sequencing for 'ErrorT' over the argument type 'Any'.
--
-- The computation is first run at 'undefined'. A 'Left' from that probe is
-- returned as is. Otherwise the per-argument computations are sequenced with
-- the underlying monad's instance, and each result is unwrapped from 'Right'.
instance (Error e, Ditraversable (->) m Any) => Ditraversable (->) (ErrorT e m) Any where
    dimapM f g = disequence (f . g)
    disequence h = ErrorT (runErrorT (h undefined) >>= continue)
        where
          -- Decide what to do with the probe's outcome.
          continue (Left err) = return (Left err)
          continue (Right _)  =
              liftM Right (disequence (liftM unRight . runErrorT . h))
          unRight (Right x) = x
          unRight (Left _)  = error "fromRight: expected Right"
-- | Sequencing for 'StateT' over the argument type 'Any'.
--
-- The probe at 'undefined' runs from the incoming state, and its final state
-- becomes the final state of the result. The returned function re-runs the
-- computation at each argument, starting from the incoming state, and keeps
-- only the value.
instance (Ditraversable (->) m Any) => Ditraversable (->) (StateT s m) Any where
    dimapM f g = disequence (f . g)
    disequence h = StateT $ \st -> do
        (_, st') <- runStateT (h undefined) st
        fun <- disequence (\x -> liftM fst (runStateT (h x) st))
        return (fun, st')
-- | Sequencing for 'WriterT' over the argument type 'Any'.
--
-- The output logged by the probe at 'undefined' becomes the output of the
-- result. The returned function re-runs the computation at each argument
-- and keeps only the value.
instance (Monoid w, Ditraversable (->) m Any) => Ditraversable (->) (WriterT w m) Any where
    dimapM f g = disequence (f . g)
    disequence h = WriterT $ do
        (_, out) <- runWriterT (h undefined)
        fun <- disequence (\x -> liftM fst (runWriterT (h x)))
        return (fun, out)
-- | Sequencing for lists over the argument type 'Any'.
--
-- The probe at 'undefined' fixes the shape of the result: there is one
-- function per element of the probe's list. The function at position @i@
-- re-applies the original and selects the element at index @i@.
--
-- NOTE(review): this presumes the list length does not depend on the
-- argument. If it does, the indexing with '!!' fails at runtime. Confirm.
instance Ditraversable (->) [] Any where
    dimapM f g = disequence (f . g)
    -- zipWith walks the probe's spine lazily, so an infinite probe still
    -- produces its prefix on demand.
    disequence h = zipWith pick [0 ..] (h undefined)
        where pick i _ x = h x !! i
-- | Sequencing for 'ListT' over the argument type 'Any'.
--
-- The probe at 'undefined' determines how many result functions are built.
-- For each position @i@, the @m@ computation that selects index @i@ is
-- sequenced with the underlying monad's instance. Positions are processed in
-- order, from first to last.
instance Ditraversable (->) m Any =>  Ditraversable (->) (ListT m) Any where
    dimapM f g = disequence (f . g)
    disequence h = ListT (runListT (h undefined) >>= collect 0)
        where
          collect _ [] = return []
          collect i (_ : rest) = do
              fun <- disequence (\x -> liftM (!! i) (runListT (h x)))
              liftM (fun :) (collect (i + 1) rest)
-- | Sequencing for 'RWST' over the argument type 'Any'.
--
-- The probe at 'undefined' runs with the incoming environment and state. Its
-- final state and output become those of the result. The returned function
-- re-runs the computation at each argument from the same environment and
-- incoming state, and keeps only the value.
instance (Monoid w, Ditraversable (->) m Any) => Ditraversable (->) (RWST r w s m) Any where
    dimapM f g = disequence (f . g)
    disequence h = RWST $ \env st -> do
        (_, st', out) <- runRWST (h undefined) env st
        fun <- disequence (\x -> liftM (\(y, _, _) -> y) (runRWST (h x) env st))
        return (fun, st', out)