{-# LANGUAGE UndecidableInstances #-}
module Algebra.Monad.State (
  -- * The State Monad
  MonadState(..),
  StateT,State,
  stateT,eval,exec,state,
  (=~),(=-),gets,saving,
  Next,Prev,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,

  -- * The State Arrow
  StateA(..),stateA,
  ) where

import Algebra.Monad.RWS
import Algebra.Monad.Base

-- | Treat an @IO ()@ action as the "state" of 'IO'.
--
-- 'get' returns the action 'unit' (presumably a no-op action — confirm in
-- the prelude). 'put' does not store anything: @put a = a@, so the given
-- action is simply run. Consequently @modify f@ runs @f unit@.
--
-- NOTE(review): 'MonadState' is imported and 'IO' comes from base, so this
-- is an orphan instance.
instance MonadState (IO ()) IO where
  get = return unit
  put a = a
  modify f = put (f unit)

-- | The state monad transformer. It is represented as an 'RWST' whose two
-- extra parameters (reader and writer) are both fixed to 'Void', and every
-- listed instance is derived from that representation.
newtype StateT s m a = StateT (RWST Void Void s m a)
                     deriving (Unit,Functor,Applicative,Monad,MonadFix,
                               MonadTrans,MonadInternal,
                               MonadCont,MonadState s,MonadList)
-- | The pure state monad: 'StateT' over 'Id'.
type State s a = StateT s Id a
-- | Expose the reader capability of the underlying monad through 'StateT'.
-- Both methods delegate to the imported helpers 'ask_' / 'local_'
-- (presumably generic lifting through the transformer — confirm).
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_ ; local = local_
-- | Expose the writer capability of the underlying monad through 'StateT'.
-- All three methods delegate to the imported helpers 'tell_', 'listen_' and
-- 'censor_' (presumably generic lifting through the transformer — confirm).
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_ ; listen = listen_ ; censor = censor_
-- Instances that are standalone-derived from the RWST representation.
-- Error handling is inherited from the underlying monad.
deriving instance MonadError e m => MonadError e (StateT s m)
-- Algebraic instances require the structure on the underlying result type
-- @m (a, s, Void)@, which is the (result, state, writer) triple that the
-- RWST representation produces.
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

-- | Isomorphism between a 'StateT' and the 'RWST' it wraps.
_StateT :: Iso (StateT s m a) (StateT t n b) (RWST Void Void s m a) (RWST Void Void t n b)
_StateT = iso StateT fromStateT
  where fromStateT (StateT inner) = inner
-- | Isomorphism between a 'StateT' and its underlying state function
-- @s -> m (s, a)@. In the pair, the new state comes first and the result
-- second.
--
-- The iso unwraps the newtype ('_StateT') and the 'RWST' ('_RWST'). It then
-- converts between the RWST triple @(a, s, w)@ and the pair @(s, a)@. When
-- going from pair to triple, the writer slot is filled with 'zero'. When
-- going from triple to pair, the writer slot is dropped. The
-- @_promapping _iso@ step presumably adapts the RWST's input argument —
-- confirm.
--
-- NOTE(review): the pair-to-triple lambda uses a lazy pattern (@~(s,a)@),
-- but the triple-to-pair lambda matches strictly. Verify that this
-- asymmetry is intentional for lazy-state uses.
stateT :: (Functor m,Functor n) => Iso (StateT s m a) (StateT t n b) (s -> m (s,a)) (t -> n (t,b))
stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a)))
          ._promapping _iso._RWST._StateT
-- | Keep only the second component of a pair nested under two functors.
-- Applied to a state function @s -> m (s, a)@ (see 'stateT'), it discards
-- the final state and keeps the result.
eval :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' b)
eval = (map . map) snd
-- | Keep only the first component of a pair nested under two functors.
-- Applied to a state function @s -> m (s, a)@ (see 'stateT'), it keeps the
-- final state and discards the result.
exec :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' a)
exec = (map . map) fst
-- | Isomorphism between the pure 'State' monad and a plain function
-- @s -> (s, a)@, with the new state first and the result second. It is
-- 'stateT' with the 'Id' layer handled by @_mapping _Id@.
state :: Iso (State s a) (State t b) (s -> (s,a)) (t -> (t,b))
state = _mapping _Id.stateT

infixl 0 =-,=~

-- | Overwrite the part of the state selected by the lens with a fixed value.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
(=-) l x = modify (set l x)

-- | Update the part of the state selected by the lens with a function.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
(=~) l = modify . warp l
-- | Read the part of the current state selected by the lens.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = map (by l) get

-- | Run an action, then restore the lens-selected part of the state to the
-- value it had before the action ran. The action's result is returned.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = gets l >>= runThenRestore
  where runThenRestore saved = st <* (l =- saved)

-- * The State Arrow
-- | State arrow: @StateA m s a@ wraps @StateT s m a@. It is viewed as an
-- arrow from the initial-state type @s@ to the result type @a@.
newtype StateA m s a = StateA (StateT s m a)
-- | Isomorphism between a 'StateA' arrow and the 'StateT' it wraps.
stateA :: Iso (StateA m s a) (StateA m' s' a') (StateT s m a) (StateT s' m' a')
stateA = iso StateA fromStateA
  where fromStateA (StateA st) = st
-- | State arrows compose by using the first arrow's result as the second
-- arrow's initial state. The composite's final state is the first arrow's
-- final state; the second arrow's final state is discarded.
--
-- NOTE(review): as a consequence, @f . id@ ends in the initial state rather
-- than in @f@'s final state (assuming 'get' leaves the state unchanged).
-- The right-identity law therefore only holds up to the final state.
-- Confirm that this is intended.
instance Monad m => Category (StateA m) where
  -- The identity arrow returns its initial state as the result.
  id = StateA get
  -- Run sab from state a to get (a', b). Then run sbc starting from state b,
  -- keep its result, and pair that result with a'.
  StateA sbc . StateA sab = StateA $ (^.stateT) $ \a ->
    (sab^..stateT) a >>= \(a',b) -> (a',).snd <$> (sbc^..stateT) b
-- | Run two state arrows side by side on the two halves of a pair state.
-- The underlying state functions are combined as 'Kleisli' arrows with
-- '<#>'; the effect order is whatever that instance uses. The nested
-- @((a',c),(b',d))@ is then regrouped into @(final states, results)@.
instance Monad m => Split (StateA m) where
  StateA sac <#> StateA sbd = StateA $ (^.stateT)
                              $ map2 (\((a',c),(b',d)) -> ((a',b'),(c,d)))
                              $ (Kleisli (sac^..stateT) <#> Kleisli (sbd^..stateT)) ^.. _Kleisli
-- | Choose between two state arrows based on an 'Either' initial state.
-- Presumably a 'Left' state runs @sac@ and a 'Right' state runs @sbc@, via
-- the function instance of '<|>' — confirm. Either way, the chosen arrow's
-- final state is re-tagged with the same constructor.
instance Monad m => Choice (StateA m) where
  StateA sac <|> StateA sbc = StateA $ (^.stateT) $
                              l Left (sac^..stateT)<|>l Right (sbc^..stateT)
    -- Apply a constructor to the final-state component of each
    -- @(state, result)@ pair under two functor layers.
    where l = map2 . first

-- | Thread an accumulator through a traversable structure in traversal
-- order. For each element @a@, @f a s@ yields the next accumulator and the
-- output element. The result pairs the final accumulator with the rebuilt
-- structure.
mapAccum :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccum f t = unState (traverse toStep t)
  where toStep a = by state (f a)
        unState st = st^..state
-- | Like 'mapAccum', but drop the final accumulator and return only the
-- rebuilt structure.
mapAccum_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccum_ f t s = snd (mapAccum f t s)
-- | Like 'mapAccum', but the accumulator is threaded in reversed traversal
-- order, by running the state steps under '_Backwards'.
mapAccumR :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccumR f t = unState (traverse toStep t)
  where toStep a = by (state._Backwards) (f a)
        unState st = st^..state._Backwards
-- | Like 'mapAccumR', but drop the final accumulator and return only the
-- rebuilt structure.
mapAccumR_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccumR_ f t s = snd (mapAccumR f t s)

-- | Shift every element one slot forward in traversal order. Each position
-- receives the element that preceded it, @x@ fills the first slot, and the
-- last element is dropped. For example, presumably
-- @push [1,2,3] 0 == [0,1,2]@ for left-to-right traversal.
push :: Traversable t => t a -> a -> t a
push xs x = mapAccum_ (\cur prev -> (cur, prev)) xs x
-- | Shift every element one slot backward in traversal order. Each position
-- receives the element that followed it, @x@ fills the last slot, and the
-- first element is dropped. For example, presumably
-- @pop [1,2,3] 4 == [2,3,4]@ for left-to-right traversal.
pop :: Traversable t => t a -> a -> t a
pop xs x = mapAccumR_ (\cur after -> (cur, after)) xs x

-- | Documentation-only aliases that mark which tuple component holds the
-- neighbouring element in 'withNext' / 'withPrev'. Both are just @a@.
type Next a = a
type Prev a = a
-- | Pair each element with its predecessor in traversal order. The given
-- seed acts as the predecessor of the first element.
withPrev :: Traversable t => a -> t a -> t (Prev a,a)
withPrev start xs = mapAccum_ (\cur before -> (cur, (before, cur))) xs start
-- | Pair each element with its successor in traversal order. The given
-- seed acts as the successor of the last element.
withNext :: Traversable t => t a -> a -> t (a,Next a)
withNext xs end = mapAccumR_ (\cur after -> (cur, (cur, after))) xs end