{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverlappingInstances #-}
{-# LANGUAGE RankNTypes #-}
--------------------------------------------------------------------------------
-- |
-- Module      : Data.Comp.Param.Ditraversable
-- Copyright   : (c) 2011 Patrick Bahr, Tom Hvitved
-- License     : BSD3
-- Maintainer  : Tom Hvitved
-- Stability   : experimental
-- Portability : non-portable (GHC Extensions)
--
-- This module defines traversable difunctors.
--
--------------------------------------------------------------------------------

module Data.Comp.Param.Ditraversable
    (
     Ditraversable(..)
    ) where

import Prelude hiding (mapM, sequence, foldr)
import Data.Maybe (fromJust)
import Data.Comp.Param.Any
import Data.Comp.Param.Difunctor
import Test.QuickCheck.Gen
import Data.Functor.Identity
import Control.Monad.Reader hiding (mapM, sequence)
import Control.Monad.Error hiding (mapM, sequence)
import Control.Monad.State hiding (mapM, sequence)
import Control.Monad.List hiding (mapM, sequence)
import Control.Monad.RWS hiding (Any, mapM, sequence)
import Control.Monad.Writer hiding (Any, mapM, sequence)

{-| Difunctors representing data structures that can be traversed from left to
right.

The monad @m@ acts on the second type argument of @f@. The first type
argument @a@ is a class parameter so that an instance can fix it; several
instances in this module only exist for @a = Any@.

Minimal complete definition: 'dimapM' or 'disequence'. Each default is
written in terms of the other, so an instance defining neither will loop.
-}
class (Difunctor f, Monad m) => Ditraversable f m a where
    -- | Map a monadic action over the second type argument of @f@ and
    -- collect the effects.
    dimapM :: (b -> m c) -> f a b -> m (f a c)
    dimapM f = disequence . fmap f

    -- | Pull the monad @m@ out of the second type argument of @f@.
    disequence :: f a (m b) -> m (f a b)
    disequence = dimapM id

-- | Every argument is mapped with the same generator and size, so the
-- resulting function is built inside a single 'Gen' computation.
instance Ditraversable (->) Gen a where
    dimapM f s = MkGen run
        where run gen size a = unGen (f (s a)) gen size
    disequence s = MkGen run
        where run gen size a = unGen (s a) gen size

-- | 'Identity' has no effects, so it is simply unwrapped under the function.
instance Ditraversable (->) Identity a where
    dimapM f s = Identity run
        where run a = runIdentity (f (s a))
    disequence s = Identity run
        where run a = runIdentity (s a)

-- | The environment is supplied first; the remaining @m@ layer is then
-- pulled out using the underlying monad's instance.
instance Ditraversable (->) m a => Ditraversable (->) (ReaderT r m) a where
    dimapM f s = ReaderT (disequence . run)
        where run r a = runReaderT (f (s a)) r
    disequence s = ReaderT (disequence . run)
        where run r a = runReaderT (s a) r

{-| Functions of the type @Any -> Maybe a@ can be turned into functions of
type @Maybe (Any -> a)@. The empty type @Any@ ensures that the function is
parametric in the input, and hence the @Maybe@ monad can be pulled out. -}
instance Ditraversable (->) Maybe Any where
    dimapM f g = disequence (f . g)
    -- Probe with 'undefined': by parametricity the result cannot depend on
    -- the argument, so a 'Just' here means every application is a 'Just'.
    disequence f = do
        _ <- f undefined
        return $ \x -> fromJust $ f x

-- | Project out of 'Right'. Only applied to results that the probe run has
-- already shown to be 'Right' (see the 'Either' and 'ErrorT' instances).
expectRight :: Either e a -> a
expectRight (Right x) = x
expectRight (Left _) = error "fromRight: expected Right"

-- | The probe run @h undefined@ decides between 'Left' and 'Right'; on
-- 'Right', each application of the result function reruns @h@.
instance Ditraversable (->) (Either e) Any where
    dimapM f g = disequence (f . g)
    disequence h = case h undefined of
        Left e -> Left e
        Right _ -> Right $ expectRight . h

-- | Like the 'Either' instance, but the probe run and the per-argument runs
-- happen in the underlying monad @m@.
instance (Error e, Ditraversable (->) m Any)
    => Ditraversable (->) (ErrorT e m) Any where
    dimapM f g = disequence (f . g)
    disequence h = ErrorT $ do
        r <- runErrorT (h undefined)
        case r of
            Left e -> return $ Left e
            Right _ ->
                liftM Right $ disequence (liftM expectRight . runErrorT . h)

-- | The final state comes from the probe run @h undefined@. Each
-- application of the result function reruns @h@ from the initial state and
-- discards the state it produces.
instance (Ditraversable (->) m Any) => Ditraversable (->) (StateT s m) Any where
    dimapM f g = disequence (f . g)
    disequence h = StateT trans
      where
        trans s = do
            (_, s') <- runStateT (h undefined) s
            fun <- disequence (liftM fst . (`runStateT` s) . h)
            return (fun, s')

-- | The accumulated output comes from the probe run @h undefined@; the
-- output of the per-argument runs is discarded.
instance (Monoid w, Ditraversable (->) m Any)
    => Ditraversable (->) (WriterT w m) Any where
    dimapM f g = disequence (f . g)
    disequence h = WriterT trans
      where
        trans = do
            (_, w) <- runWriterT (h undefined)
            fun <- disequence (liftM fst . runWriterT . h)
            return (fun, w)

-- | The probe list @h undefined@ fixes how many functions are produced
-- (its elements are never forced). The @i@-th function returns the @i@-th
-- element of @h a@. This relies on @h a@ having the same length for every
-- @a@; otherwise '!!' fails.
instance Ditraversable (->) [] Any where
    dimapM f g = disequence (f . g)
    disequence h = zipWith component [0 ..] (h undefined)
      where component i _ a = h a !! i

-- | As for lists: the probe run fixes the number of results. Each result
-- function is obtained by pulling @m@ out of "the @i@-th element of @h a@",
-- sequenced from left to right.
instance Ditraversable (->) m Any => Ditraversable (->) (ListT m) Any where
    dimapM f g = disequence (f . g)
    disequence h = ListT $ (`run` 0) =<< runListT (h undefined)
      where
        run [] _ = return []
        run (_ : xs) i = do
            g <- disequence $ liftM (!! i) . runListT . h
            liftM (g :) (run xs (i + 1))

-- | Combines the 'ReaderT', 'StateT' and 'WriterT' strategies. The final
-- state and output come from the probe run @h undefined@. Per-argument runs
-- use the same environment and initial state and keep only their result.
instance (Monoid w, Ditraversable (->) m Any)
    => Ditraversable (->) (RWST r w s m) Any where
    dimapM f g = disequence (f . g)
    disequence h = RWST trans
      where
        trans r s = do
            (_, s', w) <- runRWST (h undefined) r s
            fun <- disequence (liftM fst' . (\m -> runRWST m r s) . h)
            return (fun, s', w)
        fst' (x, _, _) = x