{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances, FunctionalDependencies #-}
module SimpleH.Monad(
  module SimpleH.Applicative,

  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),

  -- * Monad utilities
  Kleisli(..),_Kleisli,
  (=<<),(<=<),(>=>),(>>),(<*=),return,
  foldlM,foldrM,while,until,
  
  -- * Common monads
  -- ** The RWS Monad
  RWST(..),RWS,

  -- *** The State Monad
  MonadState(..),
  StateT,State,
  _stateT,eval,exec,_state,
  (=~),(=-),gets,saving,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,
  
  -- *** The Reader monad
  MonadReader(..),
  ReaderT,Reader,
  _readerT,_reader,

  -- *** The Writer monad
  MonadWriter(..),
  WriterT,Writer,
  _writerT,_writer,
  mute,intercept,

  -- ** The Continuation monad
  MonadCont(..),
  ContT(..),Cont,
  evalContT,
  evalCont,

  -- ** The List monad
  MonadList(..),
  ListT,
  _listT,

  -- ** The Error Monad
  MonadError(..),try,
  EitherT,
  eitherT,runEitherT,
  ) where

import SimpleH.Classes
import SimpleH.Applicative
import SimpleH.Core hiding (flip)
import SimpleH.Traversable
import SimpleH.Lens
import qualified Control.Exception as Ex
import qualified Control.Monad.Fix as Fix

-- | The composition of two monads is a monad when the inner monad @g@ is
-- 'Traversable'. 'join' unwraps both 'Compose' layers, giving
-- @f (g (f (g a)))@. It then uses 'sequence' to swap the middle @g (f ..)@
-- into @f (g ..)@. Finally it joins the two @f@ layers and, inside them, the
-- two @g@ layers.
instance (Traversable g,Monad f,Monad g) => Monad (Compose f g) where
  join = Compose .map join.join.map sequence.getCompose.map getCompose

-- |The class of monads with a fixpoint operator: @mfix f@ feeds the
-- (lazily available) result of the action back into @f@ itself.
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a
-- 'Id' and functions both use the generic 'cfix' from this module.
instance MonadFix Id where mfix = cfix
instance MonadFix ((->) b) where mfix = cfix
-- | Value recursion for lists, following the instance in "Control.Monad.Fix".
-- The @n@-th result is the fixpoint of the @n@-th element of @f@'s output,
-- so each nondeterministic branch sees its own value. Tying the whole list
-- to its 'head' would instead feed the first branch's value to every branch.
instance MonadFix [] where
  mfix f = case fix (f . head) of
      []    -> []
      (x:_) -> x : mfix (rest . f)
    -- Total replacement for 'tail': an empty output simply ends the list.
    where rest (_:t) = t
          rest []    = []
-- The argument passed to @f@ is the 'Right' payload of the final result.
-- If the result turns out to be a 'Left' and @f@ forces its argument, the
-- argument is bottom (it goes through 'undefined').
instance MonadFix (Either e) where mfix f = fix (f . either undefined id)
-- Delegates to the 'IO' instance from "Control.Monad.Fix".
instance MonadFix IO where mfix = Fix.mfix
-- | Fixpoints for a composition: 'collect' turns @a -> f (g a)@ into
-- @f (a -> g a)@, then the inner monad's 'mfix' is applied under @f@.
instance (Contravariant f,Monad f,Traversable g,MonadFix g) => MonadFix (Compose f g) where
  mfix f = Compose (map mfix (collect (getCompose . f)))
-- | Generic fixpoint for monads that 'collect' can distribute over.
-- @collect f :: m (a -> a)@, and 'fix' is applied inside the structure.
cfix f = map fix (collect f) 
-- | Like 'mfix', but @f@ returns a pair. The knot is tied on the second
-- component, and only the first component is returned.
mfixing f = fst<$>mfix (\ ~(_,b) -> f b )

-- | Monad transformers: 'lift' embeds an action of the underlying monad.
class MonadTrans t where
  lift :: Monad m => m a -> t m a
-- | Transformers whose computations can be exposed as an @m (c,a)@, where
-- @c@ is transformer-specific context. Because @c@ is universally
-- quantified, the supplied function can neither inspect nor construct it;
-- it can only carry it through while transforming the result.
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) ->
              (t m a -> t m b)

-- | Kleisli arrows: functions @a -> m b@ viewed as morphisms of a category.
newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
-- Identity is 'pure'. Composition runs the right-hand arrow first and feeds
-- its result to the left-hand one.
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
-- Case analysis on an 'Either' input, delegating to the function instance
-- of '<|>': 'Left' values go to @f@ and 'Right' values go to @g@.
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
-- Runs @f@ on the first component and @g@ on the second. The results are
-- paired using the 'Applicative' instance of @m@.
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
-- Wrapping and unwrapping with the 'Kleisli' constructor is an isomorphism.
instance Isomorphic (a -> m b) (a -> m c) (Kleisli m a b) (Kleisli m a c) where
  _iso = iso Kleisli runKleisli
-- | The isomorphism between @a -> m b@ and @'Kleisli' m a b@.
_Kleisli = _iso :: Iso' (a -> m b) (Kleisli m a b)

-- | Generic fold through a monoid.
--
-- Each element is turned into a value of the monoid @w@ by applying @f@ and
-- converting through the isomorphism @i@. The values are combined with
-- 'foldMap'. The combined value is converted back to a function and applied
-- to the seed @e@.
--
-- Note the step function takes the element first and the accumulator
-- second: @b -> a -> c@.
folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c  
folding i f e t = at (from i) (foldMap (at i . f) t) e
-- | Monadic left fold (step function is @element -> acc -> m acc@). 'Dual'
-- reverses the composition order relative to 'foldrM'.
foldlM = folding (_Kleisli._Endo._Dual)
-- | Monadic right fold (step function is @element -> acc -> m acc@). It
-- composes 'Kleisli' endomorphisms.
foldrM = folding (_Kleisli._Endo)

-- | @while e@ runs @e@ again and again for as long as it yields 'Just'.
-- The first 'Nothing' ends the loop with @()@.
while e = loop
  where loop = e >>= maybe (return ()) (\_ -> loop)
-- | @until e@ runs @e@ again and again until it yields @'Just' x@, and
-- then returns @x@.
until e = loop
  where loop = e >>= maybe loop return

infixr 2 >>,=<<
infixr 1 <*=
-- | Sequence two actions, keeping only the second result (same as '*>').
a >> b = a *> b
-- | Bind with its arguments swapped.
f =<< m = m >>= f
-- | Right-to-left Kleisli composition: run @g@, then feed its result to @f@.
-- No fixity is declared for '<=<' or '>=>', so they use Haskell's default
-- (infixl 9).
(f <=< g) x = g x >>= f
-- | Left-to-right Kleisli composition.
(g >=> f) x = g x >>= f
-- | Run an action, run @f@ on its result for the effect only, and return the
-- original result.
m <*= f = m >>= \x -> f x >> return x
-- | Synonym for 'pure'.
return = pure

-- | Reader-Writer-State monad transformer. A computation receives the
-- environment @r@ and the current state @s@. It returns, in @m@, its result,
-- the new state and the output @w@ it produced.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }
-- | 'RWST' over the identity monad.
type RWS r w s a = RWST r w s Id a

-- | The isomorphism between a raw run function and 'RWST'.
_RWST :: Iso' ((r,s) -> m (a,s,w)) (RWST r w s m a)
_RWST = iso RWST runRWST

-- 'pure' leaves the state untouched and produces empty output.
instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure x = RWST (\ ~(_,st) -> pure (x,st,zero))
-- 'map' transforms only the result component, lazily.
instance Functor f => Functor (RWST r w s f) where
  map g (RWST run) = RWST (\input -> map (\ ~(a,s,w) -> (g a,s,w)) (run input))
instance (Monoid w,Monad m) => Applicative (RWST r w s m)
-- 'join' runs the outer computation and then the computation it returned.
-- The state is threaded through both, and their outputs are combined with
-- '+'. All matches are lazy.
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  join outer = RWST $ \ ~(env,st0) ->
    runRWST outer (env,st0) >>= \ ~(inner,st1,w1) ->
    runRWST inner (env,st1) >>= \ ~(res,st2,w2) ->
    return (res,st2,w1+w2)
-- The knot is tied on the result value only. The environment and initial
-- state @x@ are passed unchanged, and the inner monad's 'mfix' supplies @a@.
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix f = RWST (\x -> mfix (\ ~(a,_,_) -> runRWST (f a) x))
-- Invoking the escape function makes the whole 'callCC' block return
-- @(a, s, zero)@. The state reverts to the value @s@ it had on entry, and
-- any state changes or output produced inside @f@ before escaping are
-- discarded.
-- NOTE(review): the @(r,s)@ pattern is strict here, unlike the lazy
-- patterns in the other RWST instances.
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(r,s) ->
    callCC $ \k -> runRWST (f (\a -> lift (k (a,s,zero)))) (r,s)
-- Algebraic structure reuses, via newtype deriving, the instance of the
-- wrapped function type @(r,s) -> m (a,s,w)@.
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)
-- 'get', 'put' and 'modify' read or replace the threaded state and produce
-- no output.
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,s) -> pure (s,s,zero) )
  put s = RWST (\ ~(_,_) -> pure ((),s,zero) )
  modify f = RWST (\ ~(_,s) -> pure ((),f s,zero) )
-- 'ask' returns the environment. 'local' runs an action with a transformed
-- environment.
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(r,s) -> pure (r,s,zero) )
  local f (RWST m) = RWST (\ ~(r,s) -> m (f r,s) )
-- 'tell' emits output.
-- 'listen' returns an action's output alongside its result and still emits
-- that output.
-- 'censor' rewrites an action's output with the function the action returns.
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell w = RWST (\ ~(_,s) -> pure ((),s,w) )
  listen (RWST m) = RWST (m >>> map (\ ~(a,s,w) -> ((w,a),s,w) ) )
  censor (RWST m) = RWST (m >>> map (\ ~(~(a,f),s,w) -> (a,s,f w) ) )
-- Folding and traversing are only offered when no environment or initial
-- state is needed (both are 'Void'). The computation is run once on the
-- placeholder values 'vd'.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST run) = foldMap (\(a,_,_) -> a) (run (vd,vd))
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST run) = map rebuild (sequence (map distribute (run (vd,vd))))
    -- Move the effectful result out of the triple so it can be sequenced.
    where distribute (fa,s,w) = sequence ((s,w),fa)
          -- Re-wrap the sequenced inner computation as a constant RWST.
          rebuild inner = RWST (const (map (\((s,w),a) -> (a,s,w)) inner))
-- Errors come from the inner monad. 'catch' reruns the handler's
-- computation with the same environment and state @x@ that the failed
-- action received.
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw = lift.throw
  catch f (RWST m) = RWST (\x -> catch (flip runRWST x.f) (m x))
-- 'lift' runs the inner action, leaving the state unchanged and producing
-- no output.
instance Monoid w => MonadTrans (RWST r w s) where
  lift m = RWST (\ ~(_,s) -> (,s,zero) <$> m)
-- 'internal' hides the new state and output in the opaque context @c@ as
-- @(s,w)@, so the supplied function can only transform the result.
instance (Monoid w) => MonadInternal (RWST r w s) where
  internal f (RWST m) = RWST (\ x -> f (m x <&> \ ~(a,s,w) -> ((s,w),a) )
                                     <&> \ ~((s,w),b) -> (b,s,w) )
  
{-| The class of monads that carry a piece of state of type @s@. The monad
determines the state type (@m -> s@). Define 'get' plus at least one of
'put' or 'modify'; their defaults are written in terms of each other. -}
class Monad m => MonadState s m | m -> s where
  get :: m s
  put :: s -> m ()
  put = modify . const
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f
-- Lifting helpers used to define pass-through 'MonadState' instances for
-- transformers.
get_ = lift get ; put_ = lift . put ; modify_ = lift . modify  

-- | State monad transformer, implemented as an 'RWST' whose environment and
-- output types are both 'Void'.
newtype StateT s m a = StateT (RWST Void Void s m a)
                     deriving (Unit,Functor,Applicative,Monad,MonadFix,
                               MonadTrans,MonadInternal,
                               MonadCont,MonadState s)
-- | 'StateT' over the identity monad.
type State s a = StateT s Id a
-- Reader and writer effects of the inner monad are passed through with the
-- lifting helpers.
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_ ; local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
deriving instance MonadError e m => MonadError e (StateT s m)
-- Algebraic structure reuses the instance of the wrapped 'RWST'.
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

-- | The isomorphism between the underlying 'RWST' and 'StateT'.
_StateT :: Iso' (RWST Void Void s m a) (StateT s m a)
_StateT = iso StateT (\ ~(StateT s) -> s)
-- | The isomorphism between a state-passing function @s -> m (s,a)@ and
-- 'StateT'. The new state comes first and the result second. The 'RWST'
-- output slot is filled with 'zero'.
_stateT :: Functor m => Iso' (s -> m (s,a)) (StateT s m a)
_stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a)))
          ._promapping _iso._RWST._StateT
-- | For a state-passing function @s -> m (s,a)@, keep only the result
-- ('eval') or only the final state ('exec').
eval = (map . map) snd
exec = (map . map) fst
-- | '_stateT' specialised to the identity monad.
_state :: Iso' (s -> (s,a)) (State s a)
_state = _mapping _Id._stateT

-- | Overwrite the part of the state focused by the lens.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
infixl 0 =-,=~
l =- x = modify (\st -> set l x st)
-- | Apply a function to the part of the state focused by the lens.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
l =~ f = modify (\st -> warp l f st)
-- | Read the part of the state focused by the lens.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = (\st -> at l st) <$> get

-- | Run an action, then restore the focused part of the state to the value
-- it had beforehand. The action's result is returned.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = do
  saved <- gets l
  st <* (l =- saved)

-- | @mapAccum f t@ traverses @t@ in the 'State' monad. For each element, @f@
-- gives a state-passing function @s -> (s,b)@ (new state first). The result
-- maps an initial state to the final state paired with the rebuilt
-- structure.
mapAccum f t = traverse (at _state<$>f) t^.._state
-- | Like 'mapAccum', keeping only the rebuilt structure.
mapAccum_ = (map.map.map) snd mapAccum
-- | Like 'mapAccum', but the state flows through the traversal in reverse
-- order ('Backwards').
mapAccumR f t = traverse (at (_state._Backwards)<$>f) t^.._state._Backwards
-- | Like 'mapAccumR', keeping only the rebuilt structure.
mapAccumR_ = (map.map.map) snd mapAccumR

-- | @push t s@ replaces each element with the one before it. The first
-- position gets @s@.
push = mapAccum_ (,)
-- | @pop t s@ replaces each element with the one after it. The last position
-- gets @s@.
pop = mapAccumR_ (,)

-- | Pair each element's predecessor (@a@ for the first) with the element,
-- combining them with the container's 'Applicative' instance.
-- NOTE(review): if the container's '<*>' is a cartesian product (e.g. lists
-- with the usual instance), this pairs every predecessor with every element
-- rather than zipping. Confirm the intended container type.
withPrev a e = (,)<$>push e a<*>e
-- | Pair each element with its successor (@a@ for the last). The same caveat
-- as 'withPrev' applies.
withNext e a = (,)<$>e<*>pop e a

-- | Monads with a read-only environment of type @r@. 'local' runs an action
-- under a transformed environment.
-- NOTE(review): unlike 'MonadState' and 'MonadWriter', this class has no
-- @m -> r@ functional dependency.
class Monad m => MonadReader r m where
  ask :: m r
  local :: (r -> r) -> m a -> m a
-- Functions read their argument. 'local' pre-composes the environment
-- transformation.
instance MonadReader r ((->) r) where
  ask = id ; local = (>>>)
-- Lifting helpers for pass-through 'MonadReader' instances. 'local_'
-- applies 'local' to the inner computation via 'internal'.
ask_ = lift ask ; local_ f = internal (local f)
{-| Reader monad transformer, implemented as an 'RWST' whose output and state
types are both 'Void'. -}
newtype ReaderT r m a = ReaderT (RWST r Void Void m a) 
                      deriving (Functor,Unit,Applicative,Monad,MonadFix,
                                MonadTrans,MonadInternal,
                                MonadReader r,MonadCont)
-- | 'ReaderT' over the identity monad.
type Reader r a = ReaderT r Id a

-- | The isomorphism between a plain reader function @r -> m a@ and
-- 'ReaderT'. The 'Void' state and output slots are filled with the
-- placeholder 'vd'.
_readerT :: Functor m => Iso' (r -> m a) (ReaderT r m a)
_readerT = iso readerT runReaderT
  where readerT f = ReaderT (RWST (\ ~(r,_) -> f r<&>(,vd,vd) ))
        runReaderT (ReaderT (RWST f)) r = f (r,vd) <&> \ ~(a,_,_) -> a
-- | '_readerT' specialised to the identity monad.
_reader = _mapping _Id._readerT

-- State and writer effects of the inner monad are passed through with the
-- lifting helpers.
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_ ; put = put_ ; modify = modify_
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
-- Algebraic structure reuses the instance of the wrapped 'RWST'.
deriving instance Semigroup (m (a,Void,Void)) => Semigroup (ReaderT r m a)
deriving instance Monoid (m (a,Void,Void)) => Monoid (ReaderT r m a)
deriving instance Ring (m (a,Void,Void)) => Ring (ReaderT r m a)

-- | Monads that accumulate output of monoid type @w@.
--
-- * 'tell' emits output.
-- * 'listen' runs an action and also returns the output it produced.
-- * 'censor' runs an action that returns a result and a function, and
--   applies that function to the action's output.
class (Monad m,Monoid w) => MonadWriter w m | m -> w where
  tell :: w -> m ()
  listen :: m a -> m (w,a)
  censor :: m (a,w -> w) -> m a
-- Lifting helpers for pass-through 'MonadWriter' instances. 'listen_' and
-- 'censor_' go through 'internal'. They move the transformer's opaque
-- context @c@ out of the way of the writer-specific pair.
tell_ = lift . tell
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)) )
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))
-- Pairs are writers: the first component is the accumulated output.
instance Monoid w => MonadWriter w ((,) w) where
  tell w = (w,())
  listen m@(w,_) = (w,m)
  censor ~(w,~(a,f)) = (f w,a)
  
-- | Run an action, discarding everything it writes.
mute :: (MonadWriter w m,Monoid w) => m a -> m a
mute m = censor (m <&> \res -> (res, const zero))
-- | Run an action and return what it wrote alongside its result. The output
-- is suppressed in the enclosing computation.
intercept :: (MonadWriter w m,Monoid w) => m a -> m (w,a)
intercept m = mute (listen m)

{-| Writer monad transformer, implemented as an 'RWST' whose environment and
state types are both 'Void'. -}
newtype WriterT w m a = WriterT (RWST Void w Void m a)
                      deriving (Unit,Functor,Applicative,Monad,MonadFix
                               ,Foldable,Traversable
                               ,MonadTrans,MonadInternal
                               ,MonadWriter w,MonadCont)
-- | 'WriterT' over the identity monad.
type Writer w a = WriterT w Id a
-- Reader and state effects of the inner monad are passed through with the
-- lifting helpers.
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_ ; local = local_
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_ ; put = put_ ; modify = modify_
-- Algebraic structure reuses the instance of the wrapped 'RWST'.
deriving instance Semigroup (m (a,Void,w)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (a,Void,w)) => Monoid (WriterT w m a)
deriving instance Ring (m (a,Void,w)) => Ring (WriterT w m a)

-- | The isomorphism between an @m (w,a)@ computation (output first) and
-- 'WriterT'. The 'Void' environment and state slots are filled with the
-- placeholder 'vd'.
_writerT :: Functor m => Iso' (m (w,a)) (WriterT w m a)
_writerT = iso wrap unwrap
  where wrap inner = WriterT (RWST (pure (inner <&> \ ~(out,x) -> (x,vd,out) )))
        unwrap (WriterT (RWST run)) = run (vd,vd) <&> \ ~(x,_,out) -> (out,x)
-- | '_writerT' specialised to the identity monad.
_writer = _Id._writerT

{-| Monads with first-class escape continuations. @callCC f@ passes @f@ an
escape function; calling it with a value makes the whole 'callCC' return
that value. -}
class Monad m => MonadCont m where
  callCC :: ((a -> m b) -> m a) -> m a

-- | Continuation monad transformer. A computation receives the rest of the
-- program (@a -> m r@) and produces the final answer @m r@.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
                      deriving (Semigroup,Monoid,Ring)
-- | 'ContT' over the identity monad.
type Cont r a = ContT r Id a
-- 'pure' passes the value straight to the continuation.
instance Unit m => Unit (ContT r m) where pure a = ContT ($a)
-- 'map' applies @f@ to each value before it reaches the continuation.
instance Functor f => Functor (ContT r f) where
  map f (ContT c) = ContT (\kb -> c (kb . f))
-- '<*>' obtains the function first, then the argument.
instance Applicative m => Applicative (ContT r m) where
  ContT cf <*> ContT ca = ContT (\kb -> cf (\f -> ca (\a -> kb (f a))))
-- '>>=' runs @k@ with a continuation that feeds each value to @f@.
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
-- 'lift' binds the inner action directly to the continuation.
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
-- 'callCC' exposes the current continuation @k@ as an escape function. The
-- escape function ignores whatever continuation it is later invoked in.
instance Monad m => MonadCont (ContT r m) where
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Run a 'ContT' computation to completion, using 'return' as the final
-- continuation.
evalContT (ContT k) = k return
-- | Run a 'Cont' computation to a plain value.
evalCont c = getId (evalContT c)

-- 'lift' simply wraps the action.
instance MonadTrans Backwards where
  lift = Backwards
-- 'join' runs the inner action before the outer one, even though the inner
-- action is the outer one's result. The knot is tied with 'mfixing', which
-- is why 'MonadFix' is required.
instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards$mfixing (\a -> liftA2 (,) (forwards a) ma)

-- | Monads that can branch nondeterministically over a list of values.
class Monad m => MonadList m where
  fork :: [a] -> m a
-- Lists branch over themselves.
instance MonadList [] where fork = id
-- | List transformer: an inner monad @m@ that produces a list of results,
-- represented as @m `Compose` []@.
-- NOTE(review): as with the classic @m [a]@ list transformer, the monad laws
-- are presumably only guaranteed when @m@ is commutative. Confirm.
newtype ListT m a = ListT ((m`Compose`[]) a)
                    deriving (Semigroup,Monoid,
                              Functor,Applicative,Unit,Monad,
                              Foldable,Traversable)
-- | The isomorphism between @m [a]@ and 'ListT'.
_listT :: Iso' (m [a]) (ListT m a)
_listT = iso (ListT . Compose) (\(ListT (Compose m)) -> m)
-- 'fork' returns the given list as the branches, with no inner effects.
instance Monad m => MonadList (ListT m) where
  fork = at _listT . return 
-- NOTE(review): the knot is tied through the first element of the inner
-- list ('head'). Every branch therefore receives that first element's value,
-- and an empty result whose value is demanded fails in 'head'.
instance MonadFix m => MonadFix (ListT m) where
  mfix f = at _listT (mfix (at' _listT . f . head))
-- 'lift' turns each inner result into a single branch.
instance MonadTrans ListT where
  lift ma = (return<$>ma)^._listT
instance MonadState s m => MonadState s (ListT m) where
  get = get_ ; modify = modify_ ; put = put_
-- 'listen' pairs every branch with the output of the whole inner
-- computation. 'censor' combines the functions returned by all branches
-- with 'compose' and applies the result to that output.
instance MonadWriter w m => MonadWriter w (ListT m) where
  tell = lift.tell
  listen = _listT-.map sequence.listen.-_listT
  censor = _listT-.censor.map (\l -> (fst<$>l,compose (snd<$>l))).-_listT
-- Failure means "no branches". 'catch' runs the handler only when the inner
-- computation returned an empty list.
instance Monad m => MonadError Void (ListT m) where
  throw = const zero
  catch f m = (m^.._listT >>= \l -> case l of [] -> f vd^.._listT; l -> pure l)^._listT

-- | Monads with recoverable failure carrying an error of type @e@.
-- @catch h m@ runs @m@ and, if it fails with @e@, runs @h e@ instead.
class Monad m => MonadError e m where
  throw :: e -> m Void
  catch :: (e -> m a) -> m a -> m a
-- | @try d m@ runs @m@ and falls back to @d@ on failure. It is only usable
-- where the error type is 'Void'.
try d = catch (\x -> const d (x::Void))
-- 'Left' is failure. The handler sees the 'Left' payload; 'Right' values
-- pass through unchanged.
instance MonadError e (Either e) where
  throw = Left
  catch f = f<|>Right
-- The empty list is failure. The handler is called with the placeholder
-- 'vd'.
instance MonadError Void [] where
  throw = const zero
  catch f [] = f vd
  catch _ l = l
-- | Error monad transformer: an inner monad @m@ whose results are
-- @'Either' e a@.
newtype EitherT e m a = EitherT ((m`Compose`Either e) a)
                      deriving (Unit,Functor,Applicative,Monad,MonadFix
                               ,Foldable,Traversable)
-- | Wrap an @m (Either e a)@ computation.
eitherT m = EitherT (Compose m)
-- | Unwrap back to the underlying @m (Either e a)@.
runEitherT (EitherT (Compose m)) = m

-- 'Maybe' as a monad. 'Applicative' uses the class defaults, and 'join' is
-- 'fold'.
-- NOTE(review): 'join' relies on the 'Monoid' instance of @Maybe a@ having
-- 'Nothing' as its 'zero'. Confirm in SimpleH.Classes.
instance Applicative Maybe
instance Monad Maybe where join = fold
-- 'Nothing' is failure. The handler is called with the placeholder 'vd'.
instance MonadError Void Maybe where
  throw = const Nothing
  catch f Nothing = f vd
  catch _ a = a
-- | 'IO' errors are exceptions.
--
-- 'throw' uses 'Ex.throwIO' rather than the pure 'Ex.throw'. The exception
-- is then raised when the action is executed, in order with the surrounding
-- 'IO' effects. With 'Ex.throw' it would be raised whenever the 'IO' value
-- is merely evaluated (e.g. forced with 'seq').
--
-- 'catch' is "Control.Exception"'s 'Ex.catch' with the handler taken first.
instance Ex.Exception e => MonadError e IO where
  throw = Ex.throwIO
  catch = flip Ex.catch