{-# LANGUAGE TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module: Data.Enumerator.Instances.TF
-- Copyright: 2010 John Millikin
-- License: MIT
--
-- Maintainer: jmillikin@gmail.com
-- Portability: portable
--
-- Enumerator instances for monads-tf classes
--
-----------------------------------------------------------------------------
module Data.Enumerator.Instances.TF () where

import qualified Control.Exception as Exc

import Control.Monad.Trans (lift)
import qualified Control.Monad.Cont.Class as M_C
import qualified Control.Monad.Error.Class as M_E
import qualified Control.Monad.RWS.Class as M_RWS
import qualified Control.Monad.Reader.Class as M_R
import qualified Control.Monad.State.Class as M_S
import qualified Control.Monad.Writer.Class as M_W

import qualified Data.Enumerator as E
import Data.Monoid (mempty, mappend)

-- Lifts 'M_C.callCC' through 'E.Iteratee'.
--
-- The continuation captured in the inner monad expects a whole 'E.Step'.
-- The escape function handed to @f@ therefore wraps the escaped value in
-- 'E.Yield' with an empty 'E.Chunks' list before jumping.
instance M_C.MonadCont m => M_C.MonadCont (E.Iteratee a m) where
	callCC f = E.Iteratee (M_C.callCC withEscape) where
		-- Run the user's body, giving it an Iteratee-level escape.
		withEscape escape = E.runIteratee (f (jumpWith escape))
		-- Turn an escaped value into a yielded step and jump to it.
		jumpWith escape x = E.Iteratee (escape (E.Yield x (E.Chunks [])))

-- Exposes the iteratee's own error channel through monads-tf's
-- 'M_E.MonadError'.
--
-- Errors are 'Exc.SomeException' values. Throwing and catching delegate to
-- 'E.throwError' and 'E.catchError', so only 'Monad m' is required and no
-- inner 'M_E.MonadError' is needed.
instance Monad m => M_E.MonadError (E.Iteratee a m) where
	type M_E.ErrorType (E.Iteratee a m) = Exc.SomeException
	throwError err = E.throwError err
	catchError iter handler = E.catchError iter handler

-- This instance defines no methods of its own. It relies on the
-- MonadReader, MonadWriter and MonadState instances for 'E.Iteratee'
-- defined in this module.
instance M_RWS.MonadRWS m => M_RWS.MonadRWS (E.Iteratee a m) where

-- Lifts 'M_R.MonadReader' through 'E.Iteratee'.
--
-- 'M_R.ask' reads the inner monad's environment.
--
-- 'M_R.local' must cover every step of the iteratee, not just the first.
-- A step that returns 'E.Continue' is resumed later through its
-- continuation, and that later computation must still see the modified
-- environment. So each continuation is re-wrapped with @local f@,
-- mirroring how 'M_W.listen' threads itself through 'E.Continue'.
instance M_R.MonadReader m => M_R.MonadReader (E.Iteratee a m) where
	type M_R.EnvType (E.Iteratee a m) = M_R.EnvType m
	ask = lift M_R.ask
	local f m = E.Iteratee $ do
		step <- M_R.local f (E.runIteratee m)
		return $ case step of
			-- Keep the modified environment in force for later input.
			E.Continue k -> E.Continue (M_R.local f . k)
			-- Yield / Error: nothing more runs, so pass through unchanged.
			_ -> step

-- Lifts 'M_S.MonadState' through 'E.Iteratee'. The state lives in the inner
-- monad, and both accessors are forwarded to it with 'lift'.
instance M_S.MonadState m => M_S.MonadState (E.Iteratee a m) where
	type M_S.StateType (E.Iteratee a m) = M_S.StateType m
	put s = lift (M_S.put s)
	get = lift M_S.get

-- Lifts 'M_W.MonadWriter' through 'E.Iteratee'.
--
-- 'M_W.tell' is a plain lift into the inner monad.
--
-- 'M_W.listen' listens to each step separately and keeps a running total.
-- The value delivered on 'E.Yield' is therefore paired with everything
-- written since listening began, across all 'E.Continue' resumptions.
--
-- 'M_W.pass' runs one step of the inner iteratee under the inner monad's
-- 'M_W.pass'. The caller's function is applied only when that step yields.
-- Steps that continue or fail use 'id', and the continuation recurses.
-- NOTE(review): the function therefore only rewrites output produced by
-- the final (yielding) step; output from earlier steps passes through
-- unchanged. Confirm this is acceptable to callers.
instance M_W.MonadWriter m => M_W.MonadWriter (E.Iteratee a m) where
	type M_W.WriterType (E.Iteratee a m) = M_W.WriterType m
	tell w = lift (M_W.tell w)

	listen = listenFrom mempty where
		-- Run one step under the inner 'listen' and fold its output into acc.
		listenFrom acc iter = E.Iteratee $
			M_W.listen (E.runIteratee iter) >>= \ ~(step, written) ->
			return (annotate (acc `mappend` written) step)
		-- Attach the running total on Yield; thread it into continuations.
		annotate total (E.Yield x extra) = E.Yield (x, total) extra
		annotate _ (E.Error err) = E.Error err
		annotate total (E.Continue k) = E.Continue (listenFrom total . k)

	pass iter = E.Iteratee (M_W.pass (E.runIteratee iter >>= detach)) where
		-- Split a step into the outer step and the inner-monad output modifier.
		detach (E.Yield (x, f) extra) = return (E.Yield x extra, f)
		detach (E.Error err) = return (E.Error err, id)
		detach (E.Continue k) = return (E.Continue (M_W.pass . k), id)