{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances, FunctionalDependencies #-}
-- | Monad classes, monad utilities and a family of monad transformers.
--
-- The State, Reader and Writer transformers defined here are all newtypes
-- over a single 'RWST' core; continuation, list and error monads are
-- provided alongside.
module SimpleH.Monad(
  module SimpleH.Applicative,

  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),

  -- * Monad utilities
  Kleisli(..),_Kleisli,
  (=<<),(<=<),(>=>),(>>),(<*=),return,
  foldlM,foldrM,while,until,

  -- * Common monads

  -- ** The RWS Monad
  RWST(..),RWS,

  -- *** The State Monad
  MonadState(..),
  IOLens,_ioref,_mvar,
  StateT,State,
  _stateT,eval,exec,_state,
  (=~),(=-),gets,saving,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,

  -- *** The Reader monad
  MonadReader(..),
  ReaderT,Reader,
  _readerT,_reader,

  -- *** The Writer monad
  MonadWriter(..),
  WriterT,Writer,
  _writerT,_writer,
  mute,intercept,

  -- ** The Continuation monad
  MonadCont(..),
  ContT(..),Cont,
  evalContT, evalCont,

  -- ** The List monad
  MonadList(..),
  ListT,
  _listT,

  -- ** The Error Monad
  MonadError(..),try,
  EitherT,
  eitherT,runEitherT,
  ) where

import SimpleH.Classes
import SimpleH.Applicative
import SimpleH.Core hiding (flip)
import SimpleH.Traversable
import SimpleH.Lens
import qualified Control.Exception as Ex
import qualified Control.Monad.Fix as Fix
import Data.IORef
import Control.Concurrent

-- | Composing a monad with a traversable monad yields a monad:
-- the inner layer is swapped outward with 'sequence' before joining.
instance (Traversable g,Monad f,Monad g) => Monad (Compose f g) where
  join = Compose .map join.join.map sequence.getCompose.map getCompose

-- |The class of all monads that have a fixpoint
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a

instance MonadFix Id where
  mfix = cfix
instance MonadFix ((->) b) where
  mfix = cfix
-- | Delegates to base's instance, which ties the knot element by element.
-- (Using @fix (f . head)@ here would feed every element the head's
-- fixpoint, violating the left-shrinking law.)
instance MonadFix [] where
  mfix = Fix.mfix
instance MonadFix (Either e) where
  mfix f = fix (f . either undefined id)
instance MonadFix IO where
  mfix = Fix.mfix
instance (Contravariant f,Monad f,Traversable g,MonadFix g) => MonadFix (Compose f g) where
  mfix f = Compose (map mfix (collect (getCompose . f)))

-- | Fixpoint for functors that admit 'collect': distribute, then take
-- the pure fixpoint under the functor.
cfix f = map fix (collect f)

-- | Fixpoint over the second component of a pair; only the first
-- component of the final result is returned.
mfixing f = fst<$>mfix (\ ~(_,b) -> f b )

-- | Monad transformers: anything that can embed an action of the
-- underlying monad.
class MonadTrans t where
  lift :: Monad m => m a -> t m a

-- | Transformers that can push a transformation of the underlying monad
-- through their own layer. The transformer's private data travels in
-- the opaque component @c@, which the transformation must preserve.
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> (t m a -> t m b)

-- | Kleisli arrows of a monad.
newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
instance Isomorphic (a -> m b) (a -> m c) (Kleisli m a b) (Kleisli m a c) where
  _iso = iso Kleisli runKleisli
_Kleisli = _iso :: Iso' (a -> m b) (Kleisli m a b)

-- | Fold a container by sending each element's step function into a
-- monoid through the isomorphism @i@, then converting the combined
-- monoid back into a function applied to the seed @e@.
folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c
folding i f e t = at (from i) (foldMap (at i . f) t) e

-- | Monadic left fold (steps composed in dual order).
foldlM = folding (_Kleisli._Endo._Dual)
-- | Monadic right fold.
foldrM = folding (_Kleisli._Endo)

-- | Run @e@ repeatedly while it yields a 'Just'; stop at the first 'Nothing'.
while e = fix (\w -> e >>= maybe unit (const w))
-- | Run @e@ repeatedly while it yields 'Nothing'; return the payload of
-- the first 'Just'.
until e = fix (\w -> e >>= maybe w return)

infixr 2 >>,=<<
infixr 1 <*=
(>>) = (*>)
(=<<) = flip (>>=)
f <=< g = \a -> g a >>= f
(>=>) = flip (<=<)
-- | @a <*= f@ runs @a@, then @f@ on its result, and returns @a@'s result.
a <*= f = a >>= (>>)<$>f<*>return
return = pure

-- | Reader/Writer/State transformer: given an environment and a state,
-- produce an inner action yielding a result, a new state and an output.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }
type RWS r w s a = RWST r w s Id a

_RWST :: Iso' ((r,s) -> m (a,s,w)) (RWST r w s m a)
_RWST = iso RWST runRWST

instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure a = RWST (\ ~(_,s) -> pure (a,s,zero))
instance Functor f => Functor (RWST r w s f) where
  map f (RWST fa) = RWST (fa >>> map (\ ~(a,s,w) -> (f a,s,w)))
instance (Monoid w,Monad m) => Applicative (RWST r w s m)
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  -- Thread the state through both layers and combine their outputs.
  join mm = RWST (\ ~(r,s) -> do
    ~(m,s',w) <- runRWST mm (r,s)
    ~(a,s'',w') <- runRWST m (r,s')
    return (a,s'',w+w'))
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix f = RWST (\x -> mfix (\ ~(a,_,_) -> runRWST (f a) x))
-- | The escape continuation resumes with the state captured when
-- 'callCC' was entered and an empty ('zero') output.
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(r,s) -> callCC $ \k ->
    runRWST (f (\a -> lift (k (a,s,zero)))) (r,s)
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,s) -> pure (s,s,zero) )
  put s = RWST (\ _ -> pure ((),s,zero) )
  modify f = RWST (\ ~(_,s) -> pure ((),f s,zero) )
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(r,s) -> pure (r,s,zero) )
  local f (RWST m) = RWST (\ ~(r,s) -> m (f r,s) )
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell w = RWST (\ ~(_,s) -> pure ((),s,w) )
  listen (RWST m) = RWST (m >>> map (\ ~(a,s,w) -> ((w,a),s,w) ) )
  censor (RWST m) = RWST (m >>> map (\ ~(~(a,f),s,w) -> (a,s,f w) ) )
-- | Folds over the result values; only defined when the environment and
-- state are 'Void', so the computation can be run on placeholders.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST m) = foldMap (\(a,_,_) -> a).m $ (vd,vd)
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST m) =
    map (RWST . const . map (\((s,w),a) -> (a,s,w)))
    . sequence
    . map (\(a,s,w) -> sequence ((s,w),a))
    $ m (vd,vd)
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw = lift.throw
  catch f (RWST m) = RWST (\x -> catch (flip runRWST x.f) (m x))
instance Monoid w => MonadTrans (RWST r w s) where
  lift m = RWST (\ ~(_,s) -> (,s,zero) <$> m)
instance (Monoid w) => MonadInternal (RWST r w s) where
  -- The (state, output) pair is the opaque component carried through @f@.
  internal f (RWST m) = RWST (\ x -> f (m x <&> \ ~(a,s,w) -> ((s,w),a) ) <&> \ ~((s,w),b) -> (b,s,w) )

{-| A simple State Monad -}
class Monad m => MonadState s m | m -> s where
  get :: m s
  put :: s -> m ()
  put = modify . const
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f

-- | 'IO' viewed as a state monad over actions: 'get' returns 'unit' and
-- 'put' runs the given action.
instance MonadState (IO ()) IO where
  get = return unit
  put a = a
  modify f = put (f unit)

-- | A lens onto a mutable cell, usable with the 'IO' state instance.
-- The focus is an action reading the cell; setting the focus runs the
-- given action and stores its result.
type IOLens a = Lens' (IO ()) (IO a)

-- | 'IOLens' onto an 'IORef'.
_ioref :: IORef a -> IOLens a
_ioref r = lens (const (readIORef r)) (\x a -> x >> a >>= writeIORef r)

-- | 'IOLens' onto an 'MVar' treated as an always-full variable.
-- Reading uses 'readMVar', which leaves the MVar full, so writing must
-- replace the contents rather than 'putMVar' (which would block forever
-- on the full MVar). 'modifyMVar_' swaps atomically and restores the old
-- value if an exception interrupts it; like reading, it waits while the
-- MVar is empty.
_mvar :: MVar a -> IOLens a
_mvar r = lens (const (readMVar r)) (\x a -> x >> a >>= \v -> modifyMVar_ r (\_ -> return v))

-- Lift the state operations of an underlying monad through a transformer.
get_ = lift get
put_ = lift . put
modify_ = lift . modify

-- | State transformer, represented as an 'RWST' with no environment and
-- no output.
newtype StateT s m a = StateT (RWST Void Void s m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadCont,MonadState s)
type State s a = StateT s Id a
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

_StateT :: Iso' (RWST Void Void s m a) (StateT s m a)
_StateT = iso StateT (\ ~(StateT s) -> s)

-- | View a state-passing function @s -> m (s,a)@ as a 'StateT'.
_stateT :: Functor m => Iso' (s -> m (s,a)) (StateT s m a)
_stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a))) ._promapping _iso._RWST._StateT

-- | Keep only the second component under two functor layers
-- (e.g. turn @s -> m (s,a)@ into @s -> m a@).
eval = (map . map) snd
-- | Keep only the first component under two functor layers
-- (e.g. turn @s -> m (s,a)@ into @s -> m s@).
exec = (map . map) fst

_state :: Iso' (s -> (s,a)) (State s a)
_state = _mapping _Id._stateT

infixl 0 =-,=~
-- | Set the part of the state focused by a lens.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
l =- x = modify (set l x)
-- | Modify the part of the state focused by a lens.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
l =~ f = modify (warp l f)
-- | Read the part of the state focused by a lens.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = at l<$>get
-- | Run an action, then restore the focused part of the state to the
-- value it had beforehand.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = gets l >>= \s -> st <* (l =- s)

-- | Traverse with a state-passing function, left to right.
mapAccum f t = traverse (at _state<$>f) t^.._state
mapAccum_ = (map.map.map) snd mapAccum
-- | Traverse with a state-passing function, right to left.
mapAccumR f t = traverse (at (_state._Backwards)<$>f) t^.._state._Backwards
mapAccumR_ = (map.map.map) snd mapAccumR
push = mapAccum_ (,)
pop = mapAccumR_ (,)
withPrev a e = (,)<$>push e a<*>e
withNext e a = (,)<$>e<*>pop e a

-- | Monads with a read-only environment.
class Monad m => MonadReader r m where
  ask :: m r
  local :: (r -> r) -> m a -> m a
instance MonadReader r ((->) r) where
  ask = id
  local = (>>>)

-- Lift the reader operations of an underlying monad through a transformer.
ask_ = lift ask
local_ f = internal (local f)

{-| A simple Reader monad -}
newtype ReaderT r m a = ReaderT (RWST r Void Void m a)
  deriving (Functor,Unit,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadReader r,MonadCont)
type Reader r a = ReaderT r Id a

-- | View a function of the environment as a 'ReaderT'.
_readerT :: Functor m => Iso' (r -> m a) (ReaderT r m a)
_readerT = iso readerT runReaderT
  where readerT f = ReaderT (RWST (\ ~(r,_) -> f r<&>(,vd,vd) ))
        runReaderT (ReaderT (RWST f)) r = f (r,vd) <&> \ ~(a,_,_) -> a
_reader = _mapping _Id._readerT

instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_
  put = put_
  modify = modify_
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_
  listen = listen_
  censor = censor_
deriving instance Semigroup (m (a,Void,Void)) => Semigroup (ReaderT r m a)
deriving instance Monoid (m (a,Void,Void)) => Monoid (ReaderT r m a)
deriving instance Ring (m (a,Void,Void)) => Ring (ReaderT r m a)

-- | Monads that accumulate a monoidal output.
class (Monad m,Monoid w) => MonadWriter w m | m -> w where
  tell :: w -> m ()
  listen :: m a -> m (w,a)
  censor :: m (a,w -> w) -> m a

-- Lift the writer operations of an underlying monad through a transformer.
tell_ = lift . tell
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)) )
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))

instance Monoid w => MonadWriter w ((,) w) where
  tell w = (w,())
  listen m@(w,_) = (w,m)
  censor ~(w,~(a,f)) = (f w,a)

-- | Run an action but discard everything it writes.
mute :: (MonadWriter w m,Monoid w) => m a -> m a
mute m = censor (m<&>(,const zero))

-- | Run an action, returning what it wrote instead of emitting it.
intercept :: (MonadWriter w m,Monoid w) => m a -> m (w,a)
intercept = listen >>> mute

{-| A simple Writer monad -}
newtype WriterT w m a = WriterT (RWST Void w Void m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable
           ,MonadTrans,MonadInternal
           ,MonadWriter w,MonadCont)
type Writer w a = WriterT w Id a
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_
  local = local_
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_
  put = put_
  modify = modify_
deriving instance Semigroup (m (a,Void,w)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (a,Void,w)) => Monoid (WriterT w m a)
deriving instance Ring (m (a,Void,w)) => Ring (WriterT w m a)

-- | View an action producing @(output, result)@ as a 'WriterT'.
_writerT :: Functor m => Iso' (m (w,a)) (WriterT w m a)
_writerT = iso writerT runWriterT
  where writerT mw = WriterT (RWST (pure (mw <&> \ ~(w,a) -> (a,vd,w) )))
        runWriterT (WriterT (RWST m)) = m (vd,vd) <&> \ ~(a,_,w) -> (w,a)
_writer = _Id._writerT

{-| A simple continuation monad implementation -}
class Monad m => MonadCont m where
  callCC :: ((a -> m b) -> m a) -> m a

-- | Continuation-passing transformer with final answer type @r@.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
  deriving (Semigroup,Monoid,Ring)
type Cont r a = ContT r Id a

instance Unit m => Unit (ContT r m) where
  pure a = ContT ($ a)
instance Functor f => Functor (ContT r f) where
  map f (ContT c) = ContT (\kb -> c (kb . f))
instance Applicative m => Applicative (ContT r m) where
  ContT cf <*> ContT ca = ContT (\kb -> cf (\f -> ca (\a -> kb (f a))))
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
instance Monad m => MonadCont (ContT r m) where
  -- The escape continuation discards its own continuation and jumps to @k@.
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Run a continuation computation with 'return' as the final continuation.
evalContT c = runContT c return
evalCont = getId . evalContT

instance MonadTrans Backwards where
  lift = Backwards
-- | Sequencing in reverse: the inner action runs before the outer one,
-- with the outer result supplied lazily through 'mfixing'.
instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards $ mfixing (\a -> liftA2 (,) (forwards a) ma)

-- | Monads that can nondeterministically choose from a list.
class Monad m => MonadList m where
  fork :: [a] -> m a
instance MonadList [] where
  fork = id

-- | List transformer: an action in @m@ producing a list of results.
newtype ListT m a = ListT ((m`Compose`[]) a)
  deriving (Semigroup,Monoid,
            Functor,Applicative,Unit,Monad,
            Foldable,Traversable)

_listT :: Iso' (m [a]) (ListT m a)
_listT = iso (ListT . Compose) (\(ListT (Compose m)) -> m)

instance Monad m => MonadList (ListT m) where
  fork = at _listT . return
instance MonadFix m => MonadFix (ListT m) where
  -- NOTE: the knot is tied through the head of the inner list, so every
  -- element is computed from the first element's fixpoint.
  mfix f = at _listT (mfix (at' _listT . f . head))
instance MonadTrans ListT where
  lift ma = (return<$>ma)^._listT
instance MonadState s m => MonadState s (ListT m) where
  get = get_
  modify = modify_
  put = put_
instance MonadWriter w m => MonadWriter w (ListT m) where
  tell = lift.tell
  listen = _listT-.map sequence.listen.-_listT
  censor = _listT-.censor.map (\l -> (fst<$>l,compose (snd<$>l))).-_listT
-- | Failure is the empty list; 'catch' recovers only when the inner list
-- comes out empty.
instance Monad m => MonadError Void (ListT m) where
  throw = const zero
  catch f m = (m^.._listT >>= \l -> case l of
                  [] -> f vd^.._listT
                  _  -> pure l)^._listT

-- | Monads that can raise and recover from errors of type @e@.
class Monad m => MonadError e m where
  throw :: e -> m Void
  catch :: (e -> m a) -> m a -> m a

-- | @try d m@ runs @m@, falling back to @d@ on a 'Void'-typed failure.
try d = catch (\x -> const d (x::Void))

instance MonadError e (Either e) where
  throw = Left
  catch f = f<|>Right
instance MonadError Void [] where
  throw = const zero
  catch f [] = f vd
  catch _ l = l

-- | Error transformer: an action in @m@ producing @Either e a@.
newtype EitherT e m a = EitherT ((m`Compose`Either e) a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable)
eitherT = EitherT . Compose
runEitherT (EitherT m) = getCompose m

instance Applicative Maybe
instance Monad Maybe where
  join = fold
instance MonadError Void Maybe where
  throw = const Nothing
  catch f Nothing = f vd
  catch _ a = a

-- | Exceptions in 'IO'. 'Ex.throwIO' raises when the action is executed,
-- giving well-defined ordering with other IO; the pure 'Ex.throw' would
-- raise whenever the action value happened to be evaluated.
instance Ex.Exception e => MonadError e IO where
  throw = Ex.throwIO
  catch = flip Ex.catch