{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances, FunctionalDependencies #-}
module SimpleH.Monad(
  module SimpleH.Applicative,

  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),

  -- * Monad utilities
  Kleisli(..),_Kleisli,
  (=<<),(<=<),(>=>),(>>),(<*=),return,
  foldlM,foldrM,while,until,
  bind2,bind3,(>>>=),(>>>>=),
  
  -- * Common monads
  -- ** The RWS Monad
  RWST(..),RWS,

  -- *** The State Monad
  MonadState(..),
  IOLens,_ioref,_mvar,
  StateT,State,
  _stateT,eval,exec,_state,
  (=~),(=-),gets,saving,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,
  
  -- *** The Reader monad
  MonadReader(..),
  ReaderT,Reader,
  _readerT,_reader,

  -- *** The Writer monad
  MonadWriter(..),
  WriterT,Writer,
  _writerT,_writer,
  mute,intercept,

  -- ** The Continuation monad
  MonadCont(..),
  ContT(..),Cont,
  evalContT,
  evalCont,

  -- ** The List monad
  MonadList(..),
  ListT,
  _listT,

  -- ** The Error Monad
  MonadError(..),try,
  EitherT,
  _eitherT
  ) where

import SimpleH.Classes
import SimpleH.Applicative
import SimpleH.Core hiding (flip)
import SimpleH.Traversable
import SimpleH.Lens
import qualified Control.Exception as Ex
import qualified Control.Monad.Fix as Fix
import Data.IORef
import Control.Concurrent

-- The composition of two monads is a monad when the inner one is
-- 'Traversable'. reading 'join' right to left: both 'Compose' layers are
-- unwrapped into @f (g (f (g a)))@, the middle @g@ and @f@ are swapped with
-- 'sequence', the two @f@ layers are joined, then the two @g@ layers, and the
-- result is re-wrapped in 'Compose'.
instance (Traversable g,Monad f,Monad g) => Monad (f:.:g) where
  join = Compose .map join.join.map sequence.getCompose.map getCompose

-- | The class of all monads that have a fixpoint.
--
-- @mfix f@ runs @f@ on the (lazily available) result of the whole action,
-- so an action may refer to the value it is producing.
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a
instance MonadFix Id where mfix = cfix
instance MonadFix ((->) b) where mfix = cfix
-- The k-th result of @mfix f@ is the fixpoint of the k-th branch of @f@:
-- the knot for the first result is tied through 'head', and the remaining
-- results are obtained recursively from @f@ with its first branch dropped.
-- Tying every branch through 'head' alone would feed the first branch's
-- fixpoint into all of them.
instance MonadFix [] where
  mfix f = case fix (f . head) of
             []    -> []
             (x:_) -> x : mfix (dropFirst . f)
    where dropFirst (_:t) = t
          dropFirst []    = []
instance MonadFix (Either e) where mfix f = fix (f . either undefined id)
instance MonadFix IO where mfix = Fix.mfix
-- The knot is tied in @f@; each step traverses the inner @g@ layer with @f@
-- and joins the resulting nested @g@ layers.
instance (MonadFix f,Traversable g,Monad g) => MonadFix (f:.:g) where
  mfix f = Compose $ mfix (map join . traverse (getCompose . f))
-- | Fixpoint for any functor supporting 'collect': @f@ is distributed over
-- the functor with 'collect' and 'fix' is taken under it.
cfix :: Contravariant c => (a -> c a) -> c a
cfix f = map fix (collect f)

-- | Variant of 'mfix' where the knot is tied through a second component @b@:
-- @f@ receives the (lazy) @b@ it will itself produce, and only the first
-- component of its result is returned.
mfixing :: MonadFix f => (b -> f (a, b)) -> f a
mfixing f = map fst (mfix step)
  where step ~(_, b) = f b

-- | Monad transformers: type constructors @t@ such that actions of an
-- underlying monad @m@ can be embedded into @t m@ with 'lift'.
class MonadTrans t where
  lift :: Monad m => m a -> t m a
-- | Transformers whose underlying computation can be rewritten from inside.
--
-- 'internal' applies a function on underlying actions that turns an action
-- returning @(c,a)@ into one returning @(c,b)@. Because the function must be
-- polymorphic in @c@, it cannot inspect or alter the transformer-specific
-- part carried in @c@. This module uses it to implement 'local_', 'listen_'
-- and 'censor_'.
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) ->
              (t m a -> t m b)

-- | Kleisli arrows of a monad: functions @a -> m b@, wrapped so they can be
-- composed with the 'Category' operations.
newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
-- The identity arrow is 'pure'; @f . g@ runs @g@ first and binds its result
-- into @f@.
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
-- Choice is delegated to the 'Choice' instance of the wrapped functions.
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
-- Runs the two arrows on the two halves of a pair and pairs their results,
-- combining the effects with the applicative operators.
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
instance Isomorphic (a -> m b) (c -> m' d) (Kleisli m a b) (Kleisli m' c d) where
  _iso = iso Kleisli runKleisli
-- | Isomorphism between a 'Kleisli' arrow and the function it wraps.
_Kleisli :: Iso (Kleisli m a b) (Kleisli m' c d) (a -> m b) (c -> m' d)
_Kleisli = _iso

-- | Generic monoidal fold. Each element is turned into a function @a -> c@
-- by @f@ and mapped into the monoid @w@ through the isomorphism @i@. The
-- pieces are combined with 'foldMap', mapped back through @from i@, and the
-- resulting function is applied to the initial value @e@.
folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c
folding i f e t = at (from i) (foldMap (at i . f) t) e
-- | Monadic fold whose step function takes the element first and the
-- accumulator second. The per-element Kleisli arrows are combined in the
-- 'Dual' of the 'Endo' monoid, i.e. in the opposite order to 'foldrM'.
foldlM :: (Foldable t,Monad m) => (b -> a -> m a) -> a -> t b -> m a
foldlM = folding (_Kleisli._Endo._Dual)
-- | Monadic fold whose step function takes the element first and the
-- accumulator second. The per-element Kleisli arrows are combined in the
-- 'Endo' monoid.
foldrM :: (Foldable t,Monad m) => (b -> a -> m a) -> a -> t b -> m a
foldrM = folding (_Kleisli._Endo)

-- | Repeats the action for as long as it yields 'Just'. Stops at the first
-- 'Nothing' and returns 'unit'.
while :: Monad m => m (Maybe a) -> m ()
while e = loop
  where loop = e >>= maybe unit (\_ -> loop)
-- | Repeats the action for as long as it yields 'Nothing'. Returns the
-- contents of the first 'Just'.
until :: Monad m => m (Maybe a) -> m a
until e = loop
  where loop = e >>= maybe loop return

-- | Binds the results of two actions at once. The actions are combined
-- applicatively and the resulting nested action is joined.
bind2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c
bind2 f ma mb = join (map f ma <*> mb)
-- | Operator form of 'bind2', taking the two actions as a pair.
(>>>=) :: Monad m => (m a,m b) -> (a -> b -> m c) -> m c
(ma,mb) >>>= f = bind2 f ma mb
-- | Three-action version of 'bind2'.
bind3 :: Monad m => (a -> b -> c -> m d) -> m a -> m b -> m c -> m d
bind3 f ma mb mc = join (map f ma <*> mb <*> mc)
-- | Operator form of 'bind3', taking the three actions as a triple.
(>>>>=) :: Monad m => (m a,m b,m c) -> (a -> b -> c -> m d) -> m d
(ma,mb,mc) >>>>= f = bind3 f ma mb mc

infixr 2 >>,=<<
infixr 1 <*=
-- | Sequences two actions and keeps only the second result (same as '*>').
(>>) :: Applicative f => f a -> f b -> f b
a >> b = a *> b
-- | '>>=' with its arguments swapped.
(=<<) :: Monad m => (a -> m b) -> m a -> m b
f =<< m = m >>= f
-- | Right-to-left Kleisli composition: runs @g@, then @f@ on its result.
(<=<) :: Monad m => (b -> m c) -> (a -> m b) -> (a -> m c)
(f <=< g) a = g a >>= f
-- | Left-to-right Kleisli composition: runs @f@, then @g@ on its result.
(>=>) :: Monad m => (a -> m b) -> (b -> m c) -> (a -> m c)
(f >=> g) a = f a >>= g
-- | Runs @a@, then the action @f@ builds from its result, and returns the
-- result of @a@.
(<*=) :: Monad m => m a -> (a -> m b) -> m a
a <*= f = a >>= \x -> f x >> return x
-- | Synonym for 'pure'.
return :: Unit f => a -> f a
return = pure

-- | The reader-writer-state monad transformer. A computation receives an
-- environment @r@ and a state @s@, and produces in @m@ a result @a@, the new
-- state and an output @w@.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }
-- | 'RWST' over the identity monad.
type RWS r w s a = RWST r w s Id a

-- | Isomorphism between an 'RWST' and the function it wraps.
_RWST :: Iso (RWST r w s m a) (RWST r' w' s' m' a')
         ((r,s) -> m (a,s,w)) ((r',s') -> m' (a',s',w'))
_RWST = iso RWST runRWST

-- 'pure' returns the value, leaves the state unchanged and emits 'zero'.
instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure a = RWST (\ ~(_,s) -> pure (a,s,zero))
-- Maps over the result component only.
instance Functor f => Functor (RWST r w s f) where
  map f (RWST fa) = RWST (fa >>> map (\ ~(a,s,w) -> (f a,s,w)))
-- No methods given: the class defaults are used (presumably defined via the
-- 'Monad' instance below).
instance (Monoid w,Monad m) => Applicative (RWST r w s m)
-- 'join' runs the outer computation, then the computation it returned with
-- the same environment and the updated state, and combines both outputs with
-- '+'. The lazy patterns avoid forcing the intermediate triples.
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  join mm = RWST (\ ~(r,s) -> do
                     ~(m,s',w) <- runRWST mm (r,s)
                     ~(a,s'',w') <- runRWST m (r,s')
                     return (a,s'',w+w'))
-- The knot is tied through the result component; the (environment, state)
-- input is passed unchanged to every unfolding.
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix f = RWST (\x -> mfix (\ ~(a,_,_) -> runRWST (f a) x))
-- The escape continuation handed to @f@ resumes with the state captured when
-- 'callCC' was entered and an empty ('zero') output. State changes and output
-- produced inside @f@ before escaping are therefore not carried over.
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(r,s) ->
    callCC $ \k -> runRWST (f (\a -> lift (k (a,s,zero)))) (r,s)
-- Algebraic classes are derived from the wrapped function type
-- @(r,s) -> m (a,s,w)@ (presumably pointwise).
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)
-- The state is the @s@ component threaded through the computation.
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,s) -> pure (s,s,zero) )
  put s = RWST (\ _ -> pure ((),s,zero) )
  modify f = RWST (\ ~(_,s) -> pure ((),f s,zero) )
-- The environment is read with 'ask'; 'local' modifies it for the wrapped
-- computation only.
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(r,s) -> pure (r,s,zero) )
  local f (RWST m) = RWST (\ ~(r,s) -> m (f r,s) )
-- 'tell' emits its argument. 'listen' pairs the result with the output
-- (still emitting it). 'censor' applies the function returned by the
-- computation to its output.
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell w = RWST (\ ~(_,s) -> pure ((),s,w) )
  listen (RWST m) = RWST (m >>> map (\ ~(a,s,w) -> ((w,a),s,w) ) )
  censor (RWST m) = RWST (m >>> map (\ ~(~(a,f),s,w) -> (a,s,f w) ) )
-- Folding and traversing are only defined when environment and state are
-- 'Void'. The computation is then run with 'vd' for both. The folded values
-- are the results: the first component of each triple, despite the @w@
-- binder name below.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST m) = foldMap (\(w,_,_) -> w).m $ (vd,vd)
-- Traverses the results inside @m@, carrying each (state, output) pair along
-- with its result and rebuilding a constant 'RWST' afterwards.
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST m) = map (RWST . const . map (\((s,w),a) -> (a,s,w)))
                      . sequence . map (\(a,s,w) -> sequence ((s,w),a))
                      $ m (vd,vd)
-- Errors are thrown in the underlying monad. 'catch' runs the handler's
-- computation with the same environment and state as the failed one.
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw = lift.throw
  catch f (RWST m) = RWST (\x -> catch (flip runRWST x.f) (m x))
-- 'lift' leaves the state unchanged and emits 'zero'.
instance Monoid w => MonadTrans (RWST r w s) where
  lift m = RWST (\ ~(_,s) -> (,s,zero) <$> m)
-- State and output are bundled into the opaque @c@ component, so the rank-2
-- function can only act on the result.
instance (Monoid w) => MonadInternal (RWST r w s) where
  internal f (RWST m) = RWST (\ x -> f (m x <&> \ ~(a,s,w) -> ((s,w),a) )
                                     <&> \ ~((s,w),b) -> (b,s,w) )
  
-- | Monads carrying a single piece of state of type @s@, determined by @m@.
--
-- Minimal complete definition: 'get' plus one of 'put' or 'modify'. The
-- default 'put' is defined via 'modify', and the default 'modify' via 'get'
-- and 'put'.
class Monad m => MonadState s m | m -> s where
  get :: m s
  put :: s -> m ()
  put = modify . const
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f
-- 'IO' viewed as a state monad whose "state" is an @IO ()@ action:
-- 'get' yields the no-op 'unit', @put a@ is the action @a@ itself, and
-- @modify f@ performs @f unit@. This appears to exist so that 'IOLens'es can
-- be used with the state operators such as '=-'.
instance MonadState (IO ()) IO where
  get = return unit
  put a = a
  modify f = put (f unit)
-- | A lens from the "state" of the 'IO' 'MonadState' instance (an @IO ()@
-- action) to an action producing the value held by a mutable cell.
--
-- Viewing gives an action that reads the cell. Setting with an action @a@
-- gives an action that runs the previous @IO ()@, then runs @a@ and stores
-- its result in the cell.
type IOLens a = Lens' (IO ()) (IO a)
-- | 'IOLens' onto an 'IORef'.
_ioref :: IORef a -> IOLens a
_ioref r = lens (const (readIORef r)) (\x a -> x >> a >>= writeIORef r)
-- | 'IOLens' onto an 'MVar'. Reading uses 'readMVar', which waits for the
-- 'MVar' to be full.
--
-- Writing replaces the current contents with 'modifyMVar_' rather than
-- 'putMVar'. 'putMVar' blocks for as long as the 'MVar' is full, which is
-- exactly the state reading relies on. A set would therefore hang until some
-- other thread emptied the 'MVar'.
--
-- The new value is computed by @a@ before the 'MVar' is taken, so @a@ may
-- itself read the same 'MVar' without deadlocking. Setting an empty 'MVar'
-- waits until it is filled, just as reading does.
_mvar :: MVar a -> IOLens a
_mvar r = lens (const (readMVar r))
               (\x a -> x >> a >>= \v -> modifyMVar_ r (const (pure v)))

-- | 'get' from the monad underneath a transformer.
get_ :: (MonadTrans t, MonadState a m) => t m a
get_ = lift get
-- | 'put' into the monad underneath a transformer.
put_ :: (MonadTrans t, MonadState s m) => s -> t m ()
put_ s = lift (put s)
-- | 'modify' the state of the monad underneath a transformer.
modify_ :: (MonadTrans t, MonadState s m) => (s -> s) -> t m ()
modify_ f = lift (modify f)

-- | The state monad transformer: an 'RWST' with no environment and no output
-- (both 'Void'). The instances in the deriving clause are inherited from
-- the wrapped 'RWST'.
newtype StateT s m a = StateT (RWST Void Void s m a)
                     deriving (Unit,Functor,Applicative,Monad,MonadFix,
                               MonadTrans,MonadInternal,
                               MonadCont,MonadState s)
-- | 'StateT' over the identity monad.
type State s a = StateT s Id a
-- Reader operations are forwarded to the underlying monad.
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_ ; local = local_
-- Writer operations are forwarded to the underlying monad.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

-- | Isomorphism between a 'StateT' and the 'RWST' it wraps.
_StateT :: Iso (StateT s m a) (StateT t n b) (RWST Void Void s m a) (RWST Void Void t n b)
_StateT = iso StateT (\ ~(StateT s) -> s)
-- | Isomorphism between a 'StateT' and a state-passing function
-- @s -> m (s,a)@. Note the @(state, result)@ order.
--
-- Built from '_StateT' and '_RWST'. The pairs are converted to and from the
-- @(a,s,w)@ triples of 'RWST': the output slot is filled with 'zero' and is
-- dropped on the way back.
_stateT :: (Functor m,Functor n) => Iso (StateT s m a) (StateT t n b) (s -> m (s,a)) (t -> n (t,b))
_stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a)))
          ._promapping _iso._RWST._StateT
-- | Keeps the second component of pairs under two functor layers, e.g. the
-- result of an unwrapped @s -> m (s,a)@ state function.
eval :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' b)
eval = map2 snd
-- | Keeps the first component of pairs under two functor layers, e.g. the
-- final state of an unwrapped @s -> m (s,a)@ state function.
exec :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' a)
exec = map2 fst
-- | '_stateT' specialised to the identity monad.
_state :: Iso (State s a) (State t b) (s -> (s,a)) (t -> (t,b))
_state = _mapping _Id._stateT

infixl 0 =-,=~
-- | Overwrites the part of the state focused by the lens @l@ with @x@.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
(=-) l x = modify (set l x)
-- | Applies @f@ to the part of the state focused by the lens @l@.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
(=~) l f = modify (warp l f)
-- | Reads the part of the state focused by the lens @l@.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = map (at l) get
-- | Runs @st@, then restores the part of the state focused by @l@ to the
-- value it had before @st@ ran. Returns the result of @st@.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = gets l >>= restoreAfter
  where restoreAfter old = st <* (l =- old)

-- | Traverses a structure while threading an accumulator in traversal order.
-- @step x acc@ returns the next accumulator and the output for @x@. The
-- result is the final accumulator paired with the rebuilt structure.
mapAccum :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccum step xs = traverse (at _state . step) xs ^.. _state
-- | Like 'mapAccum', but keeps only the rebuilt structure.
mapAccum_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccum_ step xs s0 = snd (mapAccum step xs s0)
-- | Like 'mapAccum', but the stateful traversal runs under 'Backwards', so
-- the accumulator is threaded in reverse order.
mapAccumR :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccumR step xs = traverse (at (_state._Backwards) . step) xs ^.. _state._Backwards
-- | Like 'mapAccumR', but keeps only the rebuilt structure.
mapAccumR_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccumR_ step xs s0 = snd (mapAccumR step xs s0)

-- | Shifts every element one position forward in traversal order. The given
-- value takes the first slot and the last element is dropped,
-- e.g. @push [1,2,3] 0 == [0,1,2]@.
push :: Traversable t => t a -> a -> t a
push xs new = mapAccum_ (\x prev -> (x, prev)) xs new
-- | Like 'push', in the reverse direction. The given value takes the last
-- slot and the first element is dropped.
pop :: Traversable t => t a -> a -> t a
pop xs new = mapAccumR_ (\x next -> (x, next)) xs new
-- | Pairs each element with its predecessor, as @(previous, current)@. The
-- given value acts as the predecessor of the first element.
withPrev :: Traversable t => a -> t a -> t (a,a)
withPrev start xs = mapAccum_ (\x prev -> (x, (prev, x))) xs start
-- | Pairs each element with its successor, as @(next, current)@. The given
-- value acts as the successor of the last element.
withNext :: Traversable t => t a -> a -> t (a,a)
withNext xs end = mapAccumR_ (\x next -> (x, (next, x))) xs end

-- | Monads that can read an environment of type @r@ ('ask') and run a
-- sub-computation in a modified environment ('local').
--
-- NOTE(review): unlike 'MonadState' and 'MonadWriter', this class has no
-- functional dependency @m -> r@, so uses of 'ask' may need a type
-- annotation. Confirm this is intended.
class Monad m => MonadReader r m where
  ask :: m r
  local :: (r -> r) -> m a -> m a
-- A function reads its argument as the environment; 'local' pre-composes
-- the modification.
instance MonadReader r ((->) r) where
  ask = id ; local = (>>>)
-- | 'ask' the monad underneath a transformer.
ask_ :: (MonadTrans t, MonadReader a m) => t m a
ask_ = lift ask
-- | Runs 'local' on the underlying monad, through 'internal'.
local_ :: (MonadInternal t, MonadReader r m) => (r -> r) -> t m a -> t m a
local_ f = internal (local f)
{-| The reader monad transformer: an 'RWST' with no state and no output
(both 'Void'). The instances in the deriving clause are inherited from the
wrapped 'RWST'. -}
newtype ReaderT r m a = ReaderT (RWST r Void Void m a)
                      deriving (Functor,Unit,Applicative,Monad,MonadFix,
                                MonadTrans,MonadInternal,
                                MonadReader r,MonadCont)
-- | 'ReaderT' over the identity monad.
type Reader r a = ReaderT r Id a

-- | Isomorphism between a 'ReaderT' and a function of the environment.
-- The unused state and output slots are filled with 'vd' and discarded on
-- the way back.
_readerT :: (Functor m,Functor m') => Iso (ReaderT r m a) (ReaderT r' m' b) (r -> m a) (r' -> m' b)
_readerT = iso readerT runReaderT
  where readerT f = ReaderT (RWST (\ ~(r,_) -> f r<&>(,vd,vd) ))
        runReaderT (ReaderT (RWST f)) r = f (r,vd) <&> \ ~(a,_,_) -> a
-- | '_readerT' specialised to the identity monad.
_reader :: Iso (Reader r a) (Reader r' b) (r -> a) (r' -> b)
_reader = _mapping _Id._readerT

-- State operations are forwarded to the underlying monad.
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_ ; put = put_ ; modify = modify_
-- Writer operations are forwarded to the underlying monad.
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
deriving instance Semigroup (m (a,Void,Void)) => Semigroup (ReaderT r m a)
deriving instance Monoid (m (a,Void,Void)) => Monoid (ReaderT r m a)
deriving instance Ring (m (a,Void,Void)) => Ring (ReaderT r m a)

-- | Monads that produce an output of monoid type @w@, determined by @m@.
--
-- * 'tell' emits an output.
-- * 'listen' runs an action and also returns the output it produced.
-- * 'censor' runs an action that returns a result and a function, and
--   applies that function to the action's output.
class (Monad m,Monoid w) => MonadWriter w m | m -> w where
  tell :: w -> m ()
  listen :: m a -> m (w,a)
  censor :: m (a,w -> w) -> m a

-- | 'tell' in the monad underneath a transformer.
tell_ :: (MonadWriter w m, MonadTrans t) => w -> t m ()
tell_ = lift . tell
-- | 'listen' through a transformer. The transformer's own component @c@ is
-- moved out of the way so that 'listen' only pairs the output with the
-- result.
listen_ :: (MonadInternal t, MonadWriter w m) => t m a -> t m (w, a)
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)) )
-- | 'censor' through a transformer, regrouping the transformer's component
-- @c@ with the result so the censoring function reaches the underlying
-- 'censor'.
censor_ :: (MonadInternal t, MonadWriter w m) => t m (a, w -> w) -> t m a
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))
-- Pairs are writer computations whose first component is the output.
instance Monoid w => MonadWriter w ((,) w) where
  tell w = (w,())
  listen m@(w,_) = (w,m)
  censor ~(w,~(a,f)) = (f w,a)
  
-- | Runs an action but replaces all of its output with 'zero'.
mute :: (MonadWriter w m,Monoid w) => m a -> m a
mute m = censor (map (\a -> (a, const zero)) m)
-- | Runs an action and returns its output alongside its result, instead of
-- emitting it.
intercept :: (MonadWriter w m,Monoid w) => m a -> m (w,a)
intercept m = mute (listen m)

{-| The writer monad transformer: an 'RWST' with no environment and no state
(both 'Void'). The instances in the deriving clause are inherited from the
wrapped 'RWST'. -}
newtype WriterT w m a = WriterT (RWST Void w Void m a)
                      deriving (Unit,Functor,Applicative,Monad,MonadFix
                               ,Foldable,Traversable
                               ,MonadTrans,MonadInternal
                               ,MonadWriter w,MonadCont)
-- | 'WriterT' over the identity monad.
type Writer w a = WriterT w Id a
-- Reader operations are forwarded to the underlying monad.
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_ ; local = local_
-- State operations are forwarded to the underlying monad.
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_ ; put = put_ ; modify = modify_
deriving instance Semigroup (m (a,Void,w)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (a,Void,w)) => Monoid (WriterT w m a)
deriving instance Ring (m (a,Void,w)) => Ring (WriterT w m a)

-- | Isomorphism between a 'WriterT' and an action returning
-- @(output, result)@. Note the output comes first. The environment and state
-- slots are filled with 'vd'.
_writerT :: (Functor m,Functor m') => Iso (WriterT w m a) (WriterT w' m' b) (m (w,a)) (m' (w',b))
_writerT = iso writerT runWriterT
  where writerT mw = WriterT (RWST (pure (mw <&> \ ~(w,a) -> (a,vd,w) )))
        runWriterT (WriterT (RWST m)) = m (vd,vd) <&> \ ~(a,_,w) -> (w,a)
-- | '_writerT' specialised to the identity monad.
_writer :: Iso (Writer w a) (Writer w' b) (w,a) (w',b)
_writer = _Id._writerT

{-| Monads supporting 'callCC' (call with current continuation). @callCC f@
passes @f@ an escape function. Applying it abandons the rest of @f@ and
makes the 'callCC' call return the given value. -}
class Monad m => MonadCont m where
  callCC :: ((a -> m b) -> m a) -> m a

-- | The continuation monad transformer. A computation receives the
-- continuation @a -> m r@ and produces the final @m r@.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
                      deriving (Semigroup,Monoid,Ring)
-- | 'ContT' over the identity monad.
type Cont r a = ContT r Id a
-- 'pure' passes the value straight to the continuation.
instance Unit m => Unit (ContT r m) where pure a = ContT ($a)
-- Mapping pre-composes the function onto the continuation.
instance Functor f => Functor (ContT r f) where
  map f (ContT c) = ContT (\kb -> c (kb . f))
-- Runs the function computation first, then the argument computation, and
-- passes the application to the continuation.
instance Applicative m => Applicative (ContT r m) where
  ContT cf <*> ContT ca = ContT (\kb -> cf (\f -> ca (\a -> kb (f a))))
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
-- A lifted action is bound directly to the continuation.
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
-- The escape function ignores the continuation it is run with and jumps to
-- the one captured by 'callCC'.
instance Monad m => MonadCont (ContT r m) where
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Runs a continuation computation with 'return' as the final continuation.
evalContT :: Unit m => ContT r m r -> m r
evalContT (ContT k) = k return
-- | Runs a pure continuation computation to its final result.
evalCont :: Cont r r -> r
evalCont c = getId (evalContT c)

-- 'lift' wraps the action in 'Backwards' unchanged.
instance MonadTrans Backwards where
  lift = Backwards
-- Sequences effects in reverse. The inner action produced by the outer one
-- is obtained lazily through 'mfixing', and is run before the outer action
-- (@forwards a@ precedes @ma@ in the 'liftA2'). This requires 'MonadFix'.
instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards$mfixing (\a -> liftA2 (,) (forwards a) ma)

-- | Monads with nondeterministic choice: @fork xs@ yields each element of
-- @xs@ as an alternative result.
class Monad m => MonadList m where
  fork :: [a] -> m a
instance MonadList [] where fork = id
-- | The list monad transformer, represented as the composition of @m@ with
-- the list monad, i.e. an action @m [a]@.
newtype ListT m a = ListT ((m:.:[]) a)
                    deriving (Semigroup,Monoid,
                              Functor,Applicative,Unit,Monad,
                              Foldable,Traversable)
-- | Isomorphism between a 'ListT' and the underlying action returning a list.
_listT :: Iso (ListT m a) (ListT m' a') (m [a]) (m' [a'])
_listT = iso (ListT . Compose) (\(ListT (Compose m)) -> m)
-- 'fork' returns the whole list from a single underlying action.
instance Monad m => MonadList (ListT m) where
  fork = at _listT . return
-- NOTE(review): the knot is tied through 'head' only, so every branch of @f@
-- receives the first result as its argument (the same pattern that gives
-- wrong per-branch results for plain lists). Confirm this is intended.
instance MonadFix m => MonadFix (ListT m) where
  mfix f = at _listT (mfix (at' _listT . f . head))
-- 'lift' turns each result of the underlying action into a singleton list.
instance MonadTrans ListT where
  lift ma = (return<$>ma)^._listT
-- State operations are forwarded to the underlying monad.
instance MonadState s m => MonadState s (ListT m) where
  get = get_ ; modify = modify_ ; put = put_
-- 'listen' pairs every element with the whole output of the underlying
-- action (via 'sequence' on the @(w,[a])@ pair). 'censor' composes all the
-- censoring functions returned by the elements and applies the result once.
instance MonadWriter w m => MonadWriter w (ListT m) where
  tell = lift.tell
  listen = _listT-.map sequence.listen.-_listT
  censor = _listT-.censor.map (\l -> (fst<$>l,compose (snd<$>l))).-_listT
-- An empty list of results is the (information-free) error. 'catch' runs
-- the handler's computation when the underlying action returns an empty
-- list, and keeps the list otherwise.
instance Monad m => MonadError Void (ListT m) where
  throw = const zero
  catch f mm = mm & _listT %%~ (\m -> m >>= \_l -> case _l of
                                   [] -> f vd^.._listT; l -> pure l)

-- | Monads that can raise errors of type @e@ ('throw') and recover from them
-- ('catch'). Note the argument order of 'catch': the handler comes first.
class Monad m => MonadError e m where
  throw :: e -> m a
  catch :: (e -> m a) -> m a -> m a
-- | @try d m@ runs @m@. If it fails with the information-free 'Void' error,
-- @d@ is run instead.
try :: MonadError Void m => m a -> m a -> m a
try d = catch (\x -> const d (x::Void))
-- 'Left' is the error. 'catch' sends a 'Left' to the handler and returns a
-- 'Right' unchanged (through the 'Choice' instance of functions).
instance MonadError e (Either e) where
  throw = Left
  catch f = f<|>Right
-- The empty list is the error. 'catch' replaces an empty list with the
-- handler's result.
instance MonadError Void [] where
  throw = const zero
  catch f [] = f vd
  catch _ l = l
-- | The error monad transformer, represented as the composition of @m@ with
-- @Either e@, i.e. an action @m (e :+: a)@.
newtype EitherT e m a = EitherT ((m:.:Either e) a)
                      deriving (Unit,Functor,Applicative,Monad,MonadFix
                               ,Foldable,Traversable)
-- 'lift' wraps each result of the underlying action with 'pure'.
instance MonadTrans (EitherT e) where
  lift m = (pure<$>m)^._eitherT
-- | Isomorphism between an 'EitherT' and the underlying action.
_eitherT :: (Functor m) => Iso (EitherT e m a) (EitherT f m b) (m (e:+:a)) (m (f:+:b))
_eitherT = iso (EitherT . Compose) (\(EitherT (Compose e)) -> e)

-- No methods given: the 'Applicative' class defaults are used for 'Maybe'.
instance Applicative Maybe
-- 'join' collapses a nested 'Maybe' with 'fold'.
instance Monad Maybe where join = fold
-- 'Nothing' is the (information-free) error of 'Maybe'.
instance MonadError Void Maybe where
  throw = const Nothing
  catch f Nothing = f vd
  catch _ a = a
-- In 'IO', errors are exceptions. 'throw' uses 'Ex.throwIO', so the
-- exception is raised when the action is executed, in sequence with the
-- surrounding IO. 'Ex.throw' would instead create an imprecise exception
-- that fires as soon as the action value is evaluated (e.g. by 'seq'),
-- before any preceding effects have run.
instance Ex.Exception e => MonadError e IO where
  throw = Ex.throwIO
  catch = flip Ex.catch