{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances, FunctionalDependencies #-}
module SimpleH.Monad(
  module SimpleH.Applicative,
  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),
  -- * Monad utilities
  Kleisli(..),_Kleisli,
  (=<<),(<=<),(>=>),(>>),(<*=),return,
  foldlM,foldrM,while,until,
  -- * Common monads
  -- ** The RWS Monad
  RWST(..),RWS,
  -- *** The State Monad
  MonadState(..),
  StateT,State,
  _stateT,eval,exec,_state,
  (=~),(=-),gets,saving,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,
  -- *** The Reader monad
  MonadReader(..),
  ReaderT,Reader,
  _readerT,_reader,
  -- *** The Writer monad
  MonadWriter(..),
  WriterT,Writer,
  _writerT,_writer,
  mute,intercept,
  -- ** The Continuation monad
  MonadCont(..),
  ContT(..),Cont,
  evalContT, evalCont,
  -- ** The List monad
  MonadList(..),
  ListT,
  _listT,
  -- ** The Error Monad
  MonadError(..),try,
  EitherT,
  eitherT,runEitherT,
  ) where

import SimpleH.Classes
import SimpleH.Applicative
import SimpleH.Core hiding (flip)
import SimpleH.Traversable
import SimpleH.Lens
import qualified Control.Exception as Ex
import qualified Control.Monad.Fix as Fix

-- A monad f composed with a traversable monad g is a monad.
-- Flattening proceeds in three steps:
--   * the middle g layer is pushed through the inner f layer with 'sequence';
--   * the two adjacent f layers are joined;
--   * the two g layers are joined inside the remaining f.
instance (Traversable g,Monad f,Monad g) => Monad (Compose f g) where
  join = Compose . map join . join . map sequence . getCompose . map getCompose

-- | The class of all monads that have a fixpoint.
--
-- @mfix f@ ties a recursive knot through the monad: the value produced by
-- the computation @f x@ is fed back as its own argument @x@.
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a

-- Both of these tie the knot inside each position via 'cfix'.
instance MonadFix Id where mfix = cfix
instance MonadFix ((->) b) where mfix = cfix
-- Delegates to base's list instance, which fixes every alternative
-- separately: element i of the result is bound to element i of itself.
-- Do not use `fix (f . head)` here. It binds every alternative to the first
-- one and so breaks the left-shrinking law. For example,
-- mfix (\x -> [1:x, 2:x]) must be [repeat 1, repeat 2], not
-- [repeat 1, 2 : repeat 1].
instance MonadFix [] where mfix = Fix.mfix
-- Delegates to base's Either instance. It recovers the fixed value from the
-- Right branch, as `fix (f . either undefined id)` did. Forcing the value of
-- a Left result now fails with an informative error instead of 'undefined'.
instance MonadFix (Either e) where mfix = Fix.mfix
instance MonadFix IO where mfix = Fix.mfix
-- The knot is tied in the inner MonadFix g, distributed over f with 'collect'.
instance (Contravariant f,Monad f,Traversable g,MonadFix g) => MonadFix (Compose f g) where
  mfix f = Compose (map mfix (collect (getCompose . f)))

-- | Fixpoint for functors that support 'collect': the knot is tied
-- separately in each position, by applying 'fix' under 'map'.
--
-- NOTE(review): assumes 'collect' turns a function @a -> f a@ into an
-- @f (a -> a)@; confirm against its definition.
cfix f = map fix (collect f)

-- | Variant of 'mfix' for pair-valued steps. The step function receives the
-- second component of its own result, and only the first component of the
-- final pair is returned.
mfixing f = fst<$>mfix (\ ~(_,b) -> f b )

-- | Monad transformers: actions of the inner monad @m@ can be lifted into
-- the transformed monad @t m@.
class MonadTrans t where
  lift :: Monad m => m a -> t m a

-- | Transformers whose own structure can be presented as an opaque
-- component @c@ paired with the result.
--
-- Any inner-monad operation that is polymorphic in @c@ can then be applied
-- underneath the transformer. The generic liftings local_, listen_ and
-- censor_ are built on this.
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> (t m a -> t m b)

-- | Kleisli arrows of a monad: functions @a -> m b@ viewed as arrows from
-- @a@ to @b@.
newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
-- The identity arrow is 'pure'. Composition runs the right-hand arrow g
-- first and binds its result into the left-hand arrow f.
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
-- Choice is inherited from the Choice instance of the underlying functions.
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
-- Runs f on the first component and g on the second, then pairs the
-- results applicatively.
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
instance Isomorphic (a -> m b) (a -> m c) (Kleisli m a b) (Kleisli m a c) where
  _iso = iso Kleisli runKleisli
-- | The isomorphism between plain monadic functions and 'Kleisli' arrows.
_Kleisli = _iso :: Iso' (a -> m b) (Kleisli m a b)

-- | Generic monoidal fold over a 'Foldable' structure.
--
-- Each element is turned into a function by @f@. That function is moved
-- into the monoid @w@ through the isomorphism @i@, and all of them are
-- combined with 'foldMap'. The combined value is converted back to a
-- function and applied to the seed @e@.
folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c
folding i f e t = at (from i) (foldMap (at i . f) t) e
-- | Monadic fold. Per-element Kleisli steps are combined under the Dual of
-- Endo, i.e. in the reverse order of 'foldrM'.
--
-- NOTE(review): presumably a monadic left fold (first element's step runs
-- first); confirm the step-function argument order against callers.
foldlM = folding (_Kleisli._Endo._Dual)
-- | Monadic fold. Per-element Kleisli steps are combined under Endo
-- composition.
foldrM = folding (_Kleisli._Endo)

-- | Runs @e@ repeatedly for as long as it yields a Just. Stops, returning
-- unit, at the first Nothing.
while e = fix (\w -> e >>= maybe (return()) (const w))
-- | Runs @e@ repeatedly until it yields Just x, and returns that x.
until e = fix (\w -> e >>= maybe w return)

infixr 2 >>,=<<
infixr 1 <*=
-- | Sequences two actions and keeps the result of the second; same as '*>'.
(>>) = (*>)
-- | Bind with its arguments flipped.
(=<<) = flip (>>=)
-- | Right-to-left Kleisli composition: the right-hand function runs first.
f <=< g = \a -> g a >>= f
-- | Left-to-right Kleisli composition.
(>=>) = flip (<=<)
-- | Runs @a@ and then runs @f@ on its result for the effects only.
-- Returns the result of @a@.
a <*= f = a >>= \a -> f a >> return a
-- | Synonym for 'pure'.
return = pure

-- | The reader-writer-state monad transformer.
--
-- A computation receives an environment @r@ and the current state @s@. In
-- the inner monad @m@ it produces its result, the new state and an output
-- @w@.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }
-- | 'RWST' over the identity monad.
type RWS r w s a = RWST r w s Id a
-- | The isomorphism between a raw reader-writer-state function and 'RWST'.
_RWST :: Iso' ((r,s) -> m (a,s,w)) (RWST r w s m a)
_RWST = iso RWST runRWST
-- Returns the value, leaves the state untouched and writes the empty output.
instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure a = RWST (\ ~(_,s) -> pure (a,s,zero))
-- Maps over the result only; state and output pass through unchanged.
instance Functor f => Functor (RWST r w s f) where
  map f (RWST fa) = RWST (fa >>> map (\ ~(a,s,w) -> (f a,s,w)))
-- NOTE(review): empty body, so presumably the Applicative methods default
-- to definitions derived from the Monad instance; confirm in SimpleH.Classes.
instance (Monoid w,Monad m) => Applicative (RWST r w s m)
-- Threads the state from the outer computation into the inner one, passes
-- both the same environment, and combines their outputs with '+'. The lazy
-- patterns keep the bind lazy in the tuples.
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  join mm = RWST (\ ~(r,s) -> do
    ~(m,s',w) <- runRWST mm (r,s)
    ~(a,s'',w') <- runRWST m (r,s')
    return (a,s'',w+w'))
-- Ties the knot on the result value with the inner monad's mfix. The same
-- (environment, state) pair x is handed to the computation built from the
-- fed-back value.
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix f = RWST (\x -> mfix (\ ~(a,_,_) -> runRWST (f a) x))
-- The escape continuation resumes with the state captured when callCC was
-- entered, not the state current at the moment of escaping. It contributes
-- 'zero' output, so output written inside the block before escaping is
-- discarded. This matches the state-rollback lifting of callCC (as in
-- transformers' liftCallCC for RWS).
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(r,s) -> callCC $ \k -> runRWST (f (\a -> lift (k (a,s,zero)))) (r,s)

-- Algebraic instances derived from those of the wrapped function type
-- (r,s) -> m (a,s,w).
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)

-- State access: each operation writes the empty output.
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,s) -> pure (s,s,zero) )
  put s = RWST (\ ~(_,_) -> pure ((),s,zero) )
  modify f = RWST (\ ~(_,s) -> pure ((),f s,zero) )
-- Environment access: 'local' runs the computation with a modified
-- environment and the current state.
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(r,s) -> pure (r,s,zero) )
  local f (RWST m) = RWST (\ ~(r,s) -> m (f r,s) )
-- Output access.
--   * listen copies the output next to the result without removing it.
--   * censor applies the function returned with the result to the output.
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell w = RWST (\ ~(_,s) -> pure ((),s,w) )
  listen (RWST m) = RWST (m >>> map (\ ~(a,s,w) -> ((w,a),s,w) ) )
  censor (RWST m) = RWST (m >>> map (\ ~(~(a,f),s,w) -> (a,s,f w) ) )

-- Only computations whose environment and state are Void can be folded or
-- traversed. The function is run on the placeholder pair (vd,vd), and the
-- results (first components of the produced triples) are folded.
-- NOTE(review): vd is presumably the placeholder inhabitant of Void from
-- SimpleH.Core; confirm.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST m) = foldMap (\(a,_,_) -> a) . m $ (vd,vd)
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST m) = map (RWST . const . map (\((s,w),a) -> (a,s,w)))
                    . sequence
                    . map (\(a,s,w) -> sequence ((s,w),a))
                    $ m (vd,vd)

-- Errors are raised in the inner monad. The handler's computation is run
-- with the same (environment, state) pair x as the failed one.
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw = lift . throw
  catch f (RWST m) = RWST (\x -> catch (flip runRWST x . f) (m x))
-- Lifting keeps the state unchanged and writes the empty output.
instance Monoid w => MonadTrans (RWST r w s) where
  lift m = RWST (\ ~(_,s) -> (,s,zero) <$> m)
-- The (state, output) pair is exposed as the opaque component c.
instance (Monoid w) => MonadInternal (RWST r w s) where
  internal f (RWST m) = RWST (\ x -> f (m x <&> \ ~(a,s,w) -> ((s,w),a) ) <&> \ ~((s,w),b) -> (b,s,w) )

-- | A simple State Monad interface: monads carrying a state of type @s@,
-- which is determined by @m@.
--
-- Minimal complete definition: 'get' and one of 'put' or 'modify'. The
-- default definitions of those two refer to each other.
class Monad m => MonadState s m | m -> s where
  -- | Read the current state.
  get :: m s
  -- | Replace the current state.
  put :: s -> m ()
  put = modify . const
  -- | Apply a function to the current state.
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f

-- Generic liftings of the state operations through any transformer. They
-- are used to write the pass-through instances of other transformers.
get_ = lift get
put_ = lift . put
modify_ = lift . modify

-- | State transformer: an 'RWST' whose environment and output types are
-- both 'Void'.
newtype StateT s m a = StateT (RWST Void Void s m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadCont,MonadState s)
-- | 'StateT' over the identity monad.
type State s a = StateT s Id a
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_ ; local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

-- The isomorphism between the underlying RWST and StateT.
_StateT :: Iso' (RWST Void Void s m a) (StateT s m a)
_StateT = iso StateT (\ ~(StateT s) -> s)
-- | The isomorphism between 'StateT' and a state-passing function that
-- returns the new state paired with the result.
_stateT :: Functor m => Iso' (s -> m (s,a)) (StateT s m a)
_stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a))) . _promapping _iso . _RWST . _StateT
-- | Keep only the result (second component) of a (state, result) pair,
-- two functor layers deep. For example, it turns @s -> m (s,a)@ into
-- @s -> m a@.
eval = (map . map) snd
-- | Keep only the state (first component) of a (state, result) pair,
-- two functor layers deep.
exec = (map . map) fst
-- | The isomorphism between a pure state-passing function and 'State'.
_state :: Iso' (s -> (s,a)) (State s a)
_state = _mapping _Id . _stateT

-- | Set the part of the state focused by the lens to the given value.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
infixl 0 =-,=~
l =- x = modify (set l x)
-- | Modify the part of the state focused by the lens with the given function.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
l =~ f = modify (warp l f)
-- | Read the part of the state focused by the lens.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = at l<$>get
-- | Run the action, then restore the focused part of the state to the value
-- it had before. Returns the action's result.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = gets l >>= \s -> st <* (l =- s)

-- | Traverse a structure with a stateful step function. The step has shape
-- element, then state, giving (new state, output). The result is a function
-- of the initial state that returns the final state paired with the
-- structure of outputs.
mapAccum f t = traverse (at _state<$>f) t^.._state
-- | Like 'mapAccum', but returns only the structure of outputs.
mapAccum_ = (map.map.map) snd mapAccum
-- | Like 'mapAccum', but the state is threaded through the Backwards
-- wrapper, which reverses the order of effects (presumably right to left).
mapAccumR f t = traverse (at (_state._Backwards)<$>f) t^.._state._Backwards
-- | Like 'mapAccumR', but returns only the structure of outputs.
mapAccumR_ = (map.map.map) snd mapAccumR
-- | Shift a structure one position forward. Each element is replaced by the
-- element before it, and the first position receives the given initial value.
push = mapAccum_ (,)
-- | Like 'push', but in the reverse direction. Each element is replaced by
-- the element after it, and the last position receives the given initial
-- value.
pop = mapAccumR_ (,)
-- | Pair each element with its predecessor; @a@ stands in before the first.
--
-- NOTE(review): the pairing uses the Applicative instance of the structure,
-- so it is element-wise only for zip-like applicatives; confirm intended use.
withPrev a e = (,)<$>push e a<*>e
-- | Pair each element with its successor; @a@ stands in after the last.
-- Same Applicative caveat as 'withPrev'.
withNext e a = (,)<$>e<*>pop e a

-- | Monads with a read-only environment of type @r@.
class Monad m => MonadReader r m where
  -- | Read the environment.
  ask :: m r
  -- | Run a computation in a modified environment.
  local :: (r -> r) -> m a -> m a
-- Functions read their argument. local pre-composes the modification.
instance MonadReader r ((->) r) where
  ask = id ; local = (>>>)
-- Generic liftings of the reader operations through a transformer.
ask_ = lift ask
local_ f = internal (local f)

-- | A simple Reader monad: an 'RWST' whose state and output types are both
-- 'Void'.
newtype ReaderT r m a = ReaderT (RWST r Void Void m a)
  deriving (Functor,Unit,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadReader r,MonadCont)
-- | 'ReaderT' over the identity monad.
type Reader r a = ReaderT r Id a
-- | The isomorphism between environment-consuming functions and 'ReaderT'.
_readerT :: Functor m => Iso' (r -> m a) (ReaderT r m a)
_readerT = iso readerT runReaderT
  where readerT f = ReaderT (RWST (\ ~(r,_) -> f r<&>(,vd,vd) ))
        runReaderT (ReaderT (RWST f)) r = f (r,vd) <&> \ ~(a,_,_) -> a
-- | '_readerT' specialised to the identity monad.
_reader = _mapping _Id . _readerT
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_ ; put = put_ ; modify = modify_
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
deriving instance Semigroup (m (a,Void,Void)) => Semigroup (ReaderT r m a)
deriving instance Monoid (m (a,Void,Void)) => Monoid (ReaderT r m a)
deriving instance Ring (m (a,Void,Void)) => Ring (ReaderT r m a)

-- | Monads that accumulate an output of monoid type @w@, determined by @m@.
class (Monad m,Monoid w) => MonadWriter w m | m -> w where
  -- | Append to the output.
  tell :: w -> m ()
  -- | Run a computation and also return the output it produced.
  listen :: m a -> m (w,a)
  -- | Run a computation that returns a function along with its result, and
  -- apply that function to the output it produced.
  censor :: m (a,w -> w) -> m a
-- Generic liftings of the writer operations. listen_ and censor_ go through
-- 'internal', moving the transformer's opaque component c out of the way.
tell_ = lift . tell
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)) )
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))
-- Pairs are writers of their first component.
instance Monoid w => MonadWriter w ((,) w) where
  tell w = (w,())
  listen m@(w,_) = (w,m)
  censor ~(w,~(a,f)) = (f w,a)

-- | Run a computation but replace the output it produces with 'zero'.
mute :: (MonadWriter w m,Monoid w) => m a -> m a
mute m = censor (m<&>(,const zero))
-- | Run a computation and return its output alongside the result, instead
-- of emitting that output.
intercept :: (MonadWriter w m,Monoid w) => m a -> m (w,a)
intercept = listen >>> mute

-- | A simple Writer monad: an 'RWST' whose environment and state types are
-- both 'Void'.
newtype WriterT w m a = WriterT (RWST Void w Void m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable
           ,MonadTrans,MonadInternal
           ,MonadWriter w,MonadCont)
-- | 'WriterT' over the identity monad.
type Writer w a = WriterT w Id a
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_ ; local = local_
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_ ; put = put_ ; modify = modify_
deriving instance Semigroup (m (a,Void,w)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (a,Void,w)) => Monoid (WriterT w m a)
deriving instance Ring (m (a,Void,w)) => Ring (WriterT w m a)
-- | The isomorphism between an inner computation returning (output, result)
-- and 'WriterT'.
_writerT :: Functor m => Iso' (m (w,a)) (WriterT w m a)
_writerT = iso writerT runWriterT
  where writerT mw = WriterT (RWST (pure (mw <&> \ ~(w,a) -> (a,vd,w) )))
        runWriterT (WriterT (RWST m)) = m (vd,vd) <&> \ ~(a,_,w) -> (w,a)
-- | '_writerT' specialised to the identity monad.
_writer = _Id . _writerT

-- | A simple continuation monad interface.
class Monad m => MonadCont m where
  -- | Call the given function with an escape continuation. Invoking that
  -- continuation is meant to make the callCC computation finish with the
  -- given value.
  callCC :: ((a -> m b) -> m a) -> m a

-- | The continuation monad transformer: a computation is a function of its
-- continuation.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
  deriving (Semigroup,Monoid,Ring)
-- | 'ContT' over the identity monad.
type Cont r a = ContT r Id a
instance Unit m => Unit (ContT r m) where
  pure a = ContT ($ a)
instance Functor f => Functor (ContT r f) where
  map f (ContT c) = ContT (\kb -> c (kb . f))
instance Applicative m => Applicative (ContT r m) where
  ContT cf <*> ContT ca = ContT (\kb -> cf (\f -> ca (\a -> kb (f a))))
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
-- The escape continuation ignores the continuation it is given and jumps
-- straight to k, the continuation of the whole callCC.
instance Monad m => MonadCont (ContT r m) where
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Run a continuation computation, finishing it with 'return'.
evalContT c = runContT c return
-- | Run a pure continuation computation.
evalCont = getId . evalContT

instance MonadTrans Backwards where
  lift = Backwards
-- Flattening runs the inner action before the outer one, so effects are
-- sequenced in reverse. The inner action is only known once the outer
-- action has produced it, hence the knot tied with 'mfixing'.
instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards $ mfixing (\a -> liftA2 (,) (forwards a) ma)

-- | Monads that can nondeterministically choose among a list of values.
class Monad m => MonadList m where
  fork :: [a] -> m a
instance MonadList [] where
  fork = id

-- | List transformer, represented as the composition of @m@ with lists.
newtype ListT m a = ListT ((m`Compose`[]) a)
  deriving (Semigroup,Monoid,
            Functor,Applicative,Unit,Monad,
            Foldable,Traversable)
-- | The isomorphism between @m@ of lists and 'ListT'.
_listT :: Iso' (m [a]) (ListT m a)
_listT = iso (ListT . Compose) (\(ListT (Compose m)) -> m)
instance Monad m => MonadList (ListT m) where
  fork = at _listT . return
-- NOTE(review): this binds every alternative to the first alternative (via
-- head). That breaks MonadFix left-shrinking for multi-alternative results
-- and fails if the fed-back list is empty; confirm before relying on it.
instance MonadFix m => MonadFix (ListT m) where
  mfix f = at _listT (mfix (at' _listT . f . head))
instance MonadTrans ListT where
  lift ma = (return<$>ma)^._listT
instance MonadState s m => MonadState s (ListT m) where
  get = get_ ; modify = modify_ ; put = put_
instance MonadWriter w m => MonadWriter w (ListT m) where
  tell = lift . tell
  listen = _listT-.map sequence.listen.-_listT
  censor = _listT-.censor.map (\l -> (fst<$>l,compose (snd<$>l))).-_listT
-- throw is the Monoid 'zero' of ListT, presumably "no alternatives".
-- catch runs the handler only when the underlying list of alternatives is
-- empty.
instance Monad m => MonadError Void (ListT m) where
  throw = const zero
  catch f m = (m^.._listT >>= \l -> case l of { [] -> f vd^.._listT; l -> pure l })^._listT

-- | Monads that can raise and handle errors of type @e@.
class Monad m => MonadError e m where
  -- | Raise an error. The result type is 'Void' because a failed
  -- computation produces no value.
  throw :: e -> m Void
  -- | @catch h m@ runs @m@ and handles any error it raises with @h@.
  catch :: (e -> m a) -> m a -> m a
-- | @try d m@ runs @m@ and falls back to @d@ on failure. It applies only to
-- monads whose error type is 'Void'.
try d = catch (\x -> const d (x::Void))
instance MonadError e (Either e) where
  throw = Left
  catch f = f<|>Right
instance MonadError Void [] where
  throw = const zero
  catch f [] = f vd
  catch _ l = l

-- | Error transformer, represented as the composition of @m@ with Either e.
newtype EitherT e m a = EitherT ((m`Compose`Either e) a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable)
-- | Wrap an inner computation returning Either as an 'EitherT'.
eitherT = EitherT . Compose
-- | Unwrap an 'EitherT' into its inner computation.
runEitherT (EitherT m) = getCompose m

instance Applicative Maybe
-- NOTE(review): join folds the outer Maybe, so it relies on the Monoid on
-- the inner Maybe behaving as expected; confirm in SimpleH.
instance Monad Maybe where join = fold
instance MonadError Void Maybe where
  throw = const Nothing
  catch f Nothing = f vd
  catch _ a = a
-- Errors in IO are exceptions. throwIO is used rather than the pure throw
-- so that raising is ordered with respect to the surrounding IO actions, as
-- the Control.Exception documentation prescribes for code in IO.
instance Ex.Exception e => MonadError e IO where
  throw = Ex.throwIO
  catch = flip Ex.catch