{-# LANGUAGE UndecidableInstances #-}
-- | A combined reader / writer / state monad transformer, plus helpers that
-- lift the corresponding operations of an inner monad through a transformer.
module Algebra.Monad.RWS (
  -- * The transformer
  RWST(..),RWS,MonadInternal(..),_RWST,

  -- * Default methods
  get_,put_,modify_,local_,ask_,tell_,listen_,censor_,getAcc_
  ) where

import Algebra.Monad.Base

-- | A computation that receives an environment @r@ and an incoming state @s@
-- (as a pair) and yields, in the underlying monad @m@, a result @a@, the
-- outgoing state and an output @w@.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }

-- | 'RWST' specialised to the 'Id' monad.
type RWS r w s a = RWST r w s Id a

------------------------------------------------------------------------------
-- Instances
--
-- Most definitions below take their tuples apart with irrefutable (~)
-- patterns, so a tuple is only inspected once one of its components is
-- actually demanded.

-- Inject a value: the state passes through unchanged and the output is 'zero'.
instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure a = RWST (\ ~(_,st) -> pure (a,st,zero))

-- Map over the result only; state and output are left alone.
instance Functor f => Functor (RWST r w s f) where
  map f (RWST run) =
    RWST (\input -> map (\ ~(a,st,out) -> (f a,st,out)) (run input))

-- No methods are given here; the class's default definitions are used.
instance (Monoid w,Monad m) => Applicative (RWST r w s m)

-- Flatten a nested computation: run the outer layer, then run the inner
-- computation it produced on the threaded state. Outputs are combined with
-- '+', the outer layer's output on the left.
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  join outer = RWST (\ ~(env,st0) -> do
    ~(inner,st1,out1) <- runRWST outer (env,st0)
    ~(a,st2,out2) <- runRWST inner (env,st1)
    return (a,st2,out1+out2))

-- Delegates to the underlying 'mfix': the result component is fed back into
-- the function, which is run on the same (environment, state) input.
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix k = RWST (\input -> mfix (\ ~(fixed,_,_) -> runRWST (k fixed) input))

-- The escape function handed to @f@ resumes the underlying continuation with
-- the state that was current when 'callCC' was entered and a 'zero' output,
-- so state changes and output produced between entering 'callCC' and
-- escaping do not survive the jump.
-- NOTE(review): mtl's MonadCont instance for RWST uses the state current at
-- the jump instead (transformers' liftCallCC'); confirm the rollback here is
-- intended.
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(env,st) -> callCC $ \exit ->
    let escape a = lift (exit (a,st,zero))
    in runRWST (f escape) (env,st)

-- Derived standalone; each instance is constrained on @m (a,s,w)@, the
-- codomain of the wrapped function.
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Semiring (m (a,s,w)) => Semiring (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)

-- State operations; none of them emits output (always 'zero').
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,st) -> pure (st,st,zero))
  put new = RWST (\_ -> pure ((),new,zero))
  modify f = RWST (\ ~(_,st) -> pure ((),f st,zero))

-- Environment operations; 'local' runs the wrapped computation with a
-- modified environment and the same state.
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(env,st) -> pure (env,st,zero))
  local f (RWST run) = RWST (\ ~(env,st) -> run (f env,st))

-- Output operations: 'listen' pairs the result with the output produced,
-- 'censor' applies the function returned alongside the result to that output.
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell out = RWST (\ ~(_,st) -> pure ((),st,out))
  listen (RWST run) =
    RWST (\input -> map (\ ~(a,st,out) -> ((out,a),st,out)) (run input))
  censor (RWST run) =
    RWST (\input -> map (\ ~(~(a,f),st,out) -> (a,st,f out)) (run input))

-- Only for computations whose environment and state are 'Void': the
-- computation is run on @(zero,zero)@ and the result values (the first
-- component of each triple) are folded.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST run) = foldMap resultOf (run (zero,zero))
    where resultOf (a,_,_) = a

-- Same restriction as 'Foldable': run on @(zero,zero)@, pull the effect out
-- of each result triple, sequence, then rebuild a constant 'RWST'.
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST run) = map rewrap (sequence (map distribute (run (zero,zero))))
    where distribute (fa,st,out) = sequence ((st,out),fa)
          rewrap inner = RWST (const (map (\((st,out),a) -> (a,st,out)) inner))

-- Errors are thrown in the underlying monad; a handler is run on the same
-- (environment, state) input as the action it replaces.
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw err = lift (throw err)
  catch handler (RWST run) =
    RWST (\input -> catch (\err -> runRWST (handler err) input) (run input))

instance (Monoid w,MonadList m) => MonadList (RWST r w s m) where
  fork choices = lift (fork choices)

-- Lifting an inner action keeps the state and emits 'zero' output.
instance Monoid w => MonadTrans (RWST r w s) where
  lift inner = RWST (\ ~(_,st) -> (\a -> (a,st,zero)) <$> inner)

-- Runs the 'Id'-valued function, reads the triple back out through @i'Id@
-- and re-wraps it with 'pure' in the target monad.
instance Monoid w => ConcreteMonad (RWST r w s) where
  generalize (RWST run) = RWST (\input -> pure (run input ^.. i'Id))

-- The (state, output) pair travels through @f@ as the opaque tag @c@ and is
-- unpacked again afterwards.
instance (Monoid w) => MonadInternal (RWST r w s) where
  internal f (RWST run) = RWST (\input ->
    let tagged = run input <&> \ ~(a,st,out) -> ((st,out),a)
    in f tagged <&> \ ~((st,out),b) -> (b,st,out))

instance (Monad m, Monoid w, MonadFuture n m) => MonadFuture n (RWST r w s m) where
  future pending = lift (future pending)

-- | Transformers through which an operation on the underlying monad can be
-- threaded, provided that operation is polymorphic in an extra component @c@
-- (and therefore cannot inspect it).
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> (t m a -> t m b)

-- | Isomorphism between an 'RWST' and the function it wraps, built from the
-- 'RWST' constructor and 'runRWST'.
_RWST :: Iso (RWST r w s m a) (RWST r' w' s' m' a') ((r,s) -> m (a,s,w)) ((r',s') -> m' (a',s',w'))
_RWST = iso RWST runRWST

------------------------------------------------------------------------------
-- Default methods: each helper lifts an operation of the inner monad @m@
-- through a transformer @t@.

-- | 'get' from the inner monad, lifted.
get_ :: (MonadTrans t, MonadState a m) => t m a
get_ = lift get

-- | 'put' into the inner monad, lifted.
put_ :: (MonadTrans t, MonadState s m) => s -> t m ()
put_ new = lift (put new)

-- | 'modify' in the inner monad, lifted.
modify_ :: (MonadTrans t, MonadState s m) => (s -> s) -> t m ()
modify_ f = lift (modify f)

-- | 'ask' from the inner monad, lifted.
ask_ :: (MonadTrans t, MonadReader a m) => t m a
ask_ = lift ask

-- | 'local' of the inner monad, threaded through @t@ with 'internal'.
local_ :: (MonadInternal t, MonadReader r m) => (r -> r) -> t m a -> t m a
local_ f = internal (local f)

-- | 'tell' into the inner monad, lifted.
tell_ :: (MonadWriter w m, MonadTrans t) => w -> t m ()
tell_ out = lift (tell out)

-- | 'listen' of the inner monad, threaded through @t@ with 'internal'; the
-- heard output is moved next to the result, past the tag.
listen_ :: (MonadInternal t, MonadWriter w m) => t m a -> t m (w, a)
listen_ = internal (\action -> listen action <&> \(out,(tag,a)) -> (tag,(out,a)))

-- | 'censor' of the inner monad, threaded through @t@ with 'internal'; the
-- censoring function is moved out from behind the tag first.
censor_ :: (MonadInternal t, MonadWriter w m) => t m (a, w -> w) -> t m a
censor_ = internal (\action -> censor (action <&> \(tag,(a,f)) -> ((tag,a),f)))

-- | 'getAcc' from the inner monad, lifted.
getAcc_ :: (MonadTrans t,MonadWriterAcc w acc m) => t m acc
getAcc_ = lift getAcc