module Control.Monad.Bi (
MonadBi(..),
lazyIO,
collect,
collectN,
) where
import "mtl" Control.Monad.Reader (ReaderT, runReaderT, ask)
import "mtl" Control.Monad.State (StateT, runStateT, get, MonadIO, liftIO)
import "mtl" Control.Monad.Trans (lift)
import Control.Monad (liftM, liftM2, join)
import System.IO.Unsafe (unsafeInterleaveIO)
-- | Moves computations between an @outer@ monad and an @inner@ monad.
--
-- * 'raise' embeds an inner computation into the outer monad.
-- * 'lower' produces, inside the outer monad, an inner computation
--   corresponding to the given outer one. The instances in this module
--   capture the current outer context (via 'get' or 'ask') when doing so.
class (Monad outer, Monad inner) => MonadBi outer inner where
  -- | Run an inner computation as part of the outer monad.
  raise :: inner a -> outer a
  -- | Turn an outer computation into an inner one, built from the outer
  -- context that is current when 'lower' runs.
  lower :: outer a -> outer (inner a)
-- | Trivial layering: a monad is both the outer and the inner layer of
-- itself. 'raise' is the identity and 'lower' simply wraps the action.
instance Monad m => MonadBi m m where
  raise action = action
  lower action = return action
-- | Raise an @m3@ computation into @m1@ by going through the intermediate
-- monad @m2@: first raise @m3 -> m2@, then @m2 -> m1@.
--
-- The first argument is a type witness only. It is passed to 'asTypeOf',
-- which ignores its second argument's value, so it is never evaluated and
-- may be @undefined@.
raiseVia :: (MonadBi m1 m2, MonadBi m2 m3) => m2 a -> (m3 a -> m1 a)
raiseVia via inner = raise (raise inner `asTypeOf` via)
-- | Lower an @m1@ computation all the way to @m3@ by going through the
-- intermediate monad @m2@: lower @m1 -> m2@, lower that result
-- @m2 -> m3@, and raise the second lowering step back into @m1@.
--
-- The first argument is a type witness only (consumed by 'asTypeOf'), so it
-- is never evaluated and may be @undefined@.
lowerVia :: (MonadBi m1 m2, MonadBi m2 m3) => m2 a -> (m1 a -> m1 (m3 a))
lowerVia via outer = do
  middle <- lower outer
  raise (lower (middle `asTypeOf` via))
-- | A 'StateT' layer over @m@.
--
-- 'raise' lifts an @m@ action into 'StateT'. 'lower' reads the current
-- state and returns an @m@ computation that runs the given stateful action
-- from that state, keeping only its result. The action's final state is
-- dropped and is not written back.
instance Monad m => MonadBi (StateT s m) m where
  raise = lift
  lower action = do
    s <- get
    -- Keep the result, drop the final state.
    return (liftM fst (runStateT action s))
-- | A 'ReaderT' layer over @m@.
--
-- 'raise' lifts an @m@ action into 'ReaderT'. 'lower' reads the current
-- environment and returns an @m@ computation that runs the given action
-- in that environment.
instance Monad m => MonadBi (ReaderT c m) m where
  raise = lift
  lower action = liftM (runReaderT action) ask
-- | A 'StateT' over 'ReaderT' stack, layered directly over the base monad
-- @m@.
--
-- Both methods route through the intermediate @ReaderT c m@ layer using
-- 'raiseVia' and 'lowerVia'. The @undefined@ value only fixes that
-- intermediate type and is never forced. Naming @c@ and @m@ from the
-- instance head here relies on ScopedTypeVariables; without it the
-- intermediate monad would be ambiguous.
instance Monad m => MonadBi (StateT s (ReaderT c m)) m where
  raise inner = raiseVia (undefined :: ReaderT c m a) inner
  lower outer = lowerVia (undefined :: ReaderT c m a) outer
-- | Defer the 'IO' part of an action until its result is demanded.
--
-- The action is first 'lower'ed to 'IO', which captures the current outer
-- context immediately. The resulting 'IO' is wrapped in
-- 'unsafeInterleaveIO' and 'raise'd back, so its effects run only when
-- the returned value is forced.
--
-- Caveats:
--
-- * The timing of those deferred effects is unpredictable relative to
--   other 'IO', and any exception surfaces where the value is forced.
-- * Effects on the outer monad's own context are not propagated back. For
--   example, the 'StateT' instance in this module drops the final state in
--   'lower'.
lazyIO :: (MonadBi m IO) => m a -> m a
lazyIO action = do
  deferred <- lower action
  raise (unsafeInterleaveIO deferred)
-- | Repeatedly run @m@ and feed each result to @f@, yielding the infinite
-- list of @f@'s results.
--
-- The first element is computed eagerly: @m@ and then @f@ run when the
-- returned action runs. Every later step is wrapped in 'lazyIO', so it
-- runs only when the list is forced that far. Callers must consume only a
-- finite prefix, for example with 'collectN'.
collect :: (MonadBi m IO) => m a -> (a -> m b) -> m [b]
collect m f = loop
  where
    loop = do
      x <- m
      y <- f x
      ys <- lazyIO loop
      return (y : ys)
-- | Like 'collect', but keep only the first @n@ results. Later steps are
-- never forced, so they never run.
--
-- 'collect' performs its first iteration (@m@ and @f@) eagerly. A
-- non-positive @n@ is therefore handled up front, so that requesting no
-- elements performs no effects at all.
collectN :: (MonadBi m IO) => Int -> m a -> (a -> m b) -> m [b]
collectN n m f
  | n <= 0 = return []
  | otherwise = liftM (take n) (collect m f)