{-# LANGUAGE RankNTypes #-}
-- | Adapters between conduits and monad transformer layers.
--
-- For each supported transformer this module exports functions that run
-- the transformer layer of a conduit down into the base monad, and, for
-- most of them, a function that lifts a base-monad conduit into the
-- transformer.
module Data.Conduit.Lift (
    -- * ExceptT
    exceptC,
    runExceptC,
    catchExceptC,
    -- * CatchT
    runCatchC,
    catchCatchC,
    -- * MaybeT
    maybeC,
    runMaybeC,
    -- * ReaderT
    readerC,
    runReaderC,
    -- * StateT, lazy
    stateLC,
    runStateLC,
    evalStateLC,
    execStateLC,
    -- * StateT, strict
    stateC,
    runStateC,
    evalStateC,
    execStateC,
    -- * WriterT, lazy
    writerLC,
    runWriterLC,
    execWriterLC,
    -- * WriterT, strict
    writerC,
    runWriterC,
    execWriterC,
    -- * RWST, lazy
    rwsLC,
    runRWSLC,
    evalRWSLC,
    execRWSLC,
    -- * RWST, strict
    rwsC,
    runRWSC,
    evalRWSC,
    execRWSC
    ) where
import Data.Conduit
import Data.Conduit.Internal (ConduitT (..), Pipe (..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Data.Monoid (Monoid(..))
import qualified Control.Monad.Trans.Except as Ex
import qualified Control.Monad.Trans.Maybe as M
import qualified Control.Monad.Trans.Reader as R
import qualified Control.Monad.Trans.State.Strict as SS
import qualified Control.Monad.Trans.Writer.Strict as WS
import qualified Control.Monad.Trans.RWS.Strict as RWSS
import qualified Control.Monad.Trans.State.Lazy as SL
import qualified Control.Monad.Trans.Writer.Lazy as WL
import qualified Control.Monad.Trans.RWS.Lazy as RWSL
import Control.Monad.Catch.Pure (CatchT (runCatchT))
import Control.Exception (SomeException)
-- | Lift a conduit whose result is an 'Either' into 'Ex.ExceptT'.
--
-- The base-monad conduit is run through 'transPipe' 'lift'; its final
-- result is then re-injected into the 'Ex.ExceptT' layer, so a 'Left'
-- becomes an 'Ex.ExceptT' failure and a 'Right' becomes the result.
exceptC
  :: Monad m =>
     ConduitT i o m (Either e a) -> ConduitT i o (Ex.ExceptT e m) a
exceptC p = do
    x <- transPipe lift p
    lift $ Ex.ExceptT (return x)
-- The sibling combinators in this module are marked INLINABLE; this pragma
-- keeps 'exceptC' consistent with them.
{-# INLINABLE exceptC #-}
-- | Run the 'Ex.ExceptT' layer of a conduit in the base monad.
--
-- Each monadic step is executed with 'Ex.runExceptT'. If a step yields
-- @Left e@ the continuation is invoked with @Left e@; when the conduit
-- finishes normally with @r@ the continuation receives @Right r@.
runExceptC
  :: Monad m =>
     ConduitT i o (Ex.ExceptT e m) r -> ConduitT i o m (Either e r)
runExceptC (ConduitT k) =
    ConduitT $ \done ->
        let step pipe =
                case pipe of
                    Done r -> done (Right r)
                    PipeM act -> PipeM $
                        Ex.runExceptT act >>= \outcome ->
                            -- A failure short-circuits to the continuation.
                            return (either (done . Left) step outcome)
                    Leftover next l -> Leftover (step next) l
                    HaveOutput next o -> HaveOutput (step next) o
                    NeedInput onIn onEnd ->
                        NeedInput (step . onIn) (step . onEnd)
         in step (k Done)
{-# INLINABLE runExceptC #-}
-- | Catch an 'Ex.ExceptT' error raised by a monadic step of the first
-- conduit and continue with the handler applied to that error.
--
-- Only the steps of the first conduit are guarded: the pipe produced by the
-- handler and the continuation that runs after a normal finish are not
-- passed back through the catching loop.
catchExceptC
  :: Monad m =>
     ConduitT i o (Ex.ExceptT e m) r
     -> (e -> ConduitT i o (Ex.ExceptT e m) r)
     -> ConduitT i o (Ex.ExceptT e m) r
catchExceptC (ConduitT c0) h =
    ConduitT $ \rest ->
        let go (Done r) = rest r
            go (PipeM mp) = PipeM $ do
                -- Run the step in the base monad so a failure is observed
                -- as a value, then re-lift it into 'Ex.ExceptT'.
                eres <- lift $ Ex.runExceptT mp
                return $ case eres of
                    Left e -> unConduitT (h e) rest
                    Right p -> go p
            go (Leftover p i) = Leftover (go p) i
            go (HaveOutput p o) = HaveOutput (go p) o
            go (NeedInput x y) = NeedInput (go . x) (go . y)
         in go (c0 Done)
{-# INLINABLE catchExceptC #-}
-- | Run the 'CatchT' layer of a conduit in the base monad.
--
-- Each monadic step is executed with 'runCatchT'. A step that produces
-- @Left e@ hands @Left e@ to the continuation; a normal finish with @r@
-- hands it @Right r@.
runCatchC
  :: Monad m =>
     ConduitT i o (CatchT m) r -> ConduitT i o m (Either SomeException r)
runCatchC (ConduitT k) =
    ConduitT $ \done ->
        let step pipe =
                case pipe of
                    Done r -> done (Right r)
                    PipeM act -> PipeM $
                        runCatchT act >>= \outcome ->
                            return (either (done . Left) step outcome)
                    Leftover next l -> Leftover (step next) l
                    HaveOutput next o -> HaveOutput (step next) o
                    NeedInput onIn onEnd ->
                        NeedInput (step . onIn) (step . onEnd)
         in step (k Done)
{-# INLINABLE runCatchC #-}
-- | Catch an exception raised in the 'CatchT' layer by a monadic step of
-- the first conduit and continue with the handler applied to it.
--
-- The handler's pipe is connected directly to the continuation and is not
-- itself passed through the catching loop.
catchCatchC
  :: Monad m
  => ConduitT i o (CatchT m) r
  -> (SomeException -> ConduitT i o (CatchT m) r)
  -> ConduitT i o (CatchT m) r
catchCatchC (ConduitT k) handler =
    ConduitT $ \done ->
        let recover exc = unConduitT (handler exc) done
            step pipe =
                case pipe of
                    Done r -> done r
                    PipeM act -> PipeM $
                        -- Observe the step's outcome in the base monad,
                        -- then lift it back into 'CatchT'.
                        lift (runCatchT act) >>= \outcome ->
                            return (either recover step outcome)
                    Leftover next l -> Leftover (step next) l
                    HaveOutput next o -> HaveOutput (step next) o
                    NeedInput onIn onEnd ->
                        NeedInput (step . onIn) (step . onEnd)
         in step (k Done)
{-# INLINABLE catchCatchC #-}
-- | Lift a conduit whose result is a 'Maybe' into 'M.MaybeT'.
--
-- The base-monad conduit is run through 'transPipe' 'lift', and its final
-- 'Maybe' result is re-injected into the 'M.MaybeT' layer.
maybeC
  :: Monad m =>
     ConduitT i o m (Maybe a) -> ConduitT i o (M.MaybeT m) a
maybeC inner = transPipe lift inner >>= lift . M.MaybeT . return
{-# INLINABLE maybeC #-}
-- | Run the 'M.MaybeT' layer of a conduit in the base monad.
--
-- Each monadic step is executed with 'M.runMaybeT'. A step that yields
-- 'Nothing' hands 'Nothing' to the continuation; a normal finish with @r@
-- hands it @Just r@.
runMaybeC
  :: Monad m =>
     ConduitT i o (M.MaybeT m) r -> ConduitT i o m (Maybe r)
runMaybeC (ConduitT k) =
    ConduitT $ \done ->
        let step pipe =
                case pipe of
                    Done r -> done (Just r)
                    PipeM act -> PipeM $
                        M.runMaybeT act >>= \outcome ->
                            return (maybe (done Nothing) step outcome)
                    Leftover next l -> Leftover (step next) l
                    HaveOutput next o -> HaveOutput (step next) o
                    NeedInput onIn onEnd ->
                        NeedInput (step . onIn) (step . onEnd)
         in step (k Done)
{-# INLINABLE runMaybeC #-}
-- | Turn a function from an environment to a base-monad conduit into a
-- conduit over 'R.ReaderT': the environment is fetched with 'R.ask' and
-- the resulting conduit is lifted with 'transPipe' 'lift'.
readerC
  :: Monad m =>
     (r -> ConduitT i o m a) -> ConduitT i o (R.ReaderT r m) a
readerC k = lift R.ask >>= transPipe lift . k
{-# INLINABLE readerC #-}
-- | Run the 'R.ReaderT' layer of a conduit by supplying the environment to
-- every monadic action via 'transPipe'.
runReaderC
  :: Monad m =>
     r -> ConduitT i o (R.ReaderT r m) res -> ConduitT i o m res
runReaderC env = transPipe (\act -> R.runReaderT act env)
{-# INLINABLE runReaderC #-}
-- | Turn a state-passing conduit function into a conduit over the lazy
-- 'SL.StateT': read the current state, run the lifted base conduit with
-- it, store the state it returns, and yield its result.
stateLC
  :: Monad m =>
     (s -> ConduitT i o m (a, s)) -> ConduitT i o (SL.StateT s m) a
stateLC k =
    lift SL.get >>= \before ->
    transPipe lift (k before) >>= \(result, after) ->
    lift (SL.put after) >> return result
{-# INLINABLE stateLC #-}
-- | Shared driver for the state-like runners in this module.
--
-- Walks the pipe of a conduit over @t m@, running every monadic step with
-- the supplied runner and carrying the state value it returns on to the
-- following steps. When the pipe finishes, the result and the final state
-- are combined with the first argument and passed to the continuation.
thread :: Monad m
       => (r -> s -> res)
       -> (forall a. t m a -> s -> m (a, s))
       -> s
       -> ConduitT i o (t m) r
       -> ConduitT i o m res
thread finish runLayer initial (ConduitT k) =
    ConduitT $ \done ->
        let walk st pipe =
                case pipe of
                    Done r -> done (finish r st)
                    PipeM act -> PipeM $
                        runLayer act st >>= \(next, st') ->
                            return (walk st' next)
                    Leftover next l -> Leftover (walk st next) l
                    NeedInput onIn onEnd ->
                        NeedInput (walk st . onIn) (walk st . onEnd)
                    HaveOutput next o -> HaveOutput (walk st next) o
         in walk initial (k Done)
{-# INLINABLE thread #-}
-- | Run the lazy 'SL.StateT' layer of a conduit from the given initial
-- state, returning the result paired with the final state.
runStateLC
  :: Monad m =>
     s -> ConduitT i o (SL.StateT s m) r -> ConduitT i o m (r, s)
runStateLC = thread (\result final -> (result, final)) SL.runStateT
{-# INLINABLE runStateLC #-}
-- | Like 'runStateLC', but keep only the result and drop the final state.
evalStateLC
  :: Monad m =>
     s -> ConduitT i o (SL.StateT s m) r -> ConduitT i o m r
evalStateLC initial conduit = fst <$> runStateLC initial conduit
{-# INLINABLE evalStateLC #-}
-- | Like 'runStateLC', but keep only the final state and drop the result.
execStateLC
  :: Monad m =>
     s -> ConduitT i o (SL.StateT s m) r -> ConduitT i o m s
execStateLC initial conduit = snd <$> runStateLC initial conduit
{-# INLINABLE execStateLC #-}
-- | Turn a state-passing conduit function into a conduit over the strict
-- 'SS.StateT': read the current state, run the lifted base conduit with
-- it, store the state it returns, and yield its result.
stateC
  :: Monad m =>
     (s -> ConduitT i o m (a, s)) -> ConduitT i o (SS.StateT s m) a
stateC k =
    lift SS.get >>= \before ->
    transPipe lift (k before) >>= \(result, after) ->
    lift (SS.put after) >> return result
{-# INLINABLE stateC #-}
-- | Run the strict 'SS.StateT' layer of a conduit from the given initial
-- state, returning the result paired with the final state.
runStateC
  :: Monad m =>
     s -> ConduitT i o (SS.StateT s m) r -> ConduitT i o m (r, s)
runStateC = thread (\result final -> (result, final)) SS.runStateT
{-# INLINABLE runStateC #-}
-- | Like 'runStateC', but keep only the result and drop the final state.
evalStateC
  :: Monad m =>
     s -> ConduitT i o (SS.StateT s m) r -> ConduitT i o m r
evalStateC initial conduit = fst <$> runStateC initial conduit
{-# INLINABLE evalStateC #-}
-- | Like 'runStateC', but keep only the final state and drop the result.
execStateC
  :: Monad m =>
     s -> ConduitT i o (SS.StateT s m) r -> ConduitT i o m s
execStateC initial conduit = snd <$> runStateC initial conduit
{-# INLINABLE execStateC #-}
-- | Lift a conduit that returns a result together with some output into
-- the lazy 'WL.WriterT': the base conduit is run lifted, its output is
-- recorded with 'WL.tell', and its result is returned.
writerLC
  :: (Monad m, Monoid w) =>
     ConduitT i o m (b, w) -> ConduitT i o (WL.WriterT w m) b
writerLC inner =
    transPipe lift inner >>= \(result, logged) ->
        lift (WL.tell logged) >> return result
{-# INLINABLE writerLC #-}
-- | Run the lazy 'WL.WriterT' layer of a conduit, returning the result
-- paired with the combined output.
--
-- The accumulator starts at 'mempty'; the output of each monadic step is
-- appended on the right of what has been collected so far.
runWriterLC
  :: (Monad m, Monoid w) =>
     ConduitT i o (WL.WriterT w m) r -> ConduitT i o m (r, w)
runWriterLC = thread (,) collect mempty
  where
    collect act soFar =
        WL.runWriterT act >>= \(a, new) ->
            return (a, soFar `mappend` new)
{-# INLINABLE runWriterLC #-}
-- | Like 'runWriterLC', but keep only the collected output.
execWriterLC
  :: (Monad m, Monoid w) =>
     ConduitT i o (WL.WriterT w m) r -> ConduitT i o m w
execWriterLC conduit = snd <$> runWriterLC conduit
{-# INLINABLE execWriterLC #-}
-- | Lift a conduit that returns a result together with some output into
-- the strict 'WS.WriterT': the base conduit is run lifted, its output is
-- recorded with 'WS.tell', and its result is returned.
writerC
  :: (Monad m, Monoid w) =>
     ConduitT i o m (b, w) -> ConduitT i o (WS.WriterT w m) b
writerC inner =
    transPipe lift inner >>= \(result, logged) ->
        lift (WS.tell logged) >> return result
{-# INLINABLE writerC #-}
-- | Run the strict 'WS.WriterT' layer of a conduit, returning the result
-- paired with the combined output.
--
-- The accumulator starts at 'mempty'; the output of each monadic step is
-- appended on the right of what has been collected so far.
runWriterC
  :: (Monad m, Monoid w) =>
     ConduitT i o (WS.WriterT w m) r -> ConduitT i o m (r, w)
runWriterC = thread (,) collect mempty
  where
    collect act soFar =
        WS.runWriterT act >>= \(a, new) ->
            return (a, soFar `mappend` new)
{-# INLINABLE runWriterC #-}
-- | Like 'runWriterC', but keep only the collected output.
execWriterC
  :: (Monad m, Monoid w) =>
     ConduitT i o (WS.WriterT w m) r -> ConduitT i o m w
execWriterC conduit = snd <$> runWriterC conduit
{-# INLINABLE execWriterC #-}
-- | Turn a function from environment and state to a base-monad conduit
-- into a conduit over the lazy 'RWSL.RWST'.
--
-- The environment and current state are read, the lifted base conduit is
-- run with them, and the state and output it returns are stored with
-- 'RWSL.put' and 'RWSL.tell' (in that order) before yielding its result.
rwsLC
  :: (Monad m, Monoid w) =>
     (r -> s -> ConduitT i o m (a, s, w)) -> ConduitT i o (RWSL.RWST r w s m) a
rwsLC k = do
    env <- lift RWSL.ask
    before <- lift RWSL.get
    (result, after, logged) <- transPipe lift (k env before)
    lift (RWSL.put after >> RWSL.tell logged)
    return result
{-# INLINABLE rwsLC #-}
-- | Run the lazy 'RWSL.RWST' layer of a conduit with the given environment
-- and initial state, returning the result, final state and combined output.
--
-- State and output are carried together as a pair; the output accumulator
-- starts at 'mempty' and each step's output is appended on the right.
runRWSLC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSL.RWST r w s m) res
     -> ConduitT i o m (res, s, w)
runRWSLC env initial = thread assemble step (initial, mempty)
  where
    assemble a (st, out) = (a, st, out)
    step act (st, out) =
        RWSL.runRWST act env st >>= \(res, st', new) ->
            return (res, (st', out `mappend` new))
{-# INLINABLE runRWSLC #-}
-- | Like 'runRWSLC', but return only the result and the combined output,
-- dropping the final state.
evalRWSLC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSL.RWST r w s m) res
     -> ConduitT i o m (res, w)
evalRWSLC env initial conduit = dropState <$> runRWSLC env initial conduit
  where
    -- The lazy pattern leaves the triple unevaluated until a component
    -- of the returned pair is demanded.
    dropState ~(res, _, out) = (res, out)
{-# INLINABLE evalRWSLC #-}
-- | Like 'runRWSLC', but return only the final state and the combined
-- output, dropping the result.
execRWSLC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSL.RWST r w s m) res
     -> ConduitT i o m (s, w)
execRWSLC env initial conduit = dropResult <$> runRWSLC env initial conduit
  where
    -- The lazy pattern leaves the triple unevaluated until a component
    -- of the returned pair is demanded.
    dropResult ~(_, st, out) = (st, out)
{-# INLINABLE execRWSLC #-}
-- | Turn a function from environment and state to a base-monad conduit
-- into a conduit over the strict 'RWSS.RWST'.
--
-- The environment and current state are read, the lifted base conduit is
-- run with them, and the state and output it returns are stored with
-- 'RWSS.put' and 'RWSS.tell' (in that order) before yielding its result.
rwsC
  :: (Monad m, Monoid w) =>
     (r -> s -> ConduitT i o m (a, s, w)) -> ConduitT i o (RWSS.RWST r w s m) a
rwsC k = do
    env <- lift RWSS.ask
    before <- lift RWSS.get
    (result, after, logged) <- transPipe lift (k env before)
    lift (RWSS.put after >> RWSS.tell logged)
    return result
{-# INLINABLE rwsC #-}
-- | Run the strict 'RWSS.RWST' layer of a conduit with the given
-- environment and initial state, returning the result, final state and
-- combined output.
--
-- State and output are carried together as a pair; the output accumulator
-- starts at 'mempty' and each step's output is appended on the right.
runRWSC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSS.RWST r w s m) res
     -> ConduitT i o m (res, s, w)
runRWSC env initial = thread assemble step (initial, mempty)
  where
    assemble a (st, out) = (a, st, out)
    step act (st, out) =
        RWSS.runRWST act env st >>= \(res, st', new) ->
            return (res, (st', out `mappend` new))
{-# INLINABLE runRWSC #-}
-- | Like 'runRWSC', but return only the result and the combined output,
-- dropping the final state.
evalRWSC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSS.RWST r w s m) res
     -> ConduitT i o m (res, w)
evalRWSC env initial conduit = dropState <$> runRWSC env initial conduit
  where
    -- The lazy pattern leaves the triple unevaluated until a component
    -- of the returned pair is demanded.
    dropState ~(res, _, out) = (res, out)
{-# INLINABLE evalRWSC #-}
-- | Like 'runRWSC', but return only the final state and the combined
-- output, dropping the result.
execRWSC
  :: (Monad m, Monoid w) =>
     r
     -> s
     -> ConduitT i o (RWSS.RWST r w s m) res
     -> ConduitT i o m (s, w)
execRWSC env initial conduit = dropResult <$> runRWSC env initial conduit
  where
    -- The lazy pattern leaves the triple unevaluated until a component
    -- of the returned pair is demanded.
    dropResult ~(_, st, out) = (st, out)
{-# INLINABLE execRWSC #-}