{-# LANGUAGE UndecidableInstances #-}
module Algebra.Monad.State (
  -- * The State Monad
  MonadState(..),
  StateT,State,
  stateT,eval,exec,state,
  (=~),(=-),(^>=),gets,getl,saving,
  Next,Prev,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,

  -- * The State Arrow
  StateA(..),stateA,
  ) where

import Algebra.Monad.RWS
import Algebra.Monad.Base

-- | 'IO' viewed as a state monad whose \"state\" is an @IO ()@ action:
-- 'get' always yields 'unit', and 'put' hands the given action back
-- unchanged, so the action itself is what gets sequenced.
instance MonadState (IO ()) IO where
  get = return unit
  put action = action
  -- Equivalent to @put (f unit)@, because 'put' returns its argument as is.
  modify f = f unit

-- | State monad transformer, represented as an 'RWST' whose first two
-- parameters are fixed to 'Void', leaving only the state component @s@.
-- The listed instances come from the @deriving@ clause on this newtype
-- (presumably lifted from the wrapped 'RWST' — confirm against the
-- project's enabled deriving extensions).
newtype StateT s m a = StateT (RWST Void Void s m a)
                     deriving (Unit,Functor,Applicative,Monad,MonadFix,
                               MonadTrans,MonadInternal,
                               MonadCont,MonadState s,MonadList)
-- | 'StateT' with 'Id' as the underlying monad.
type State s a = StateT s Id a
-- | Reader operations for 'StateT', delegated to the 'ask_' and 'local_'
-- helpers.
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
-- | Writer operations for 'StateT', delegated to the 'tell_', 'listen_' and
-- 'censor_' helpers.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
-- | The accumulator is read from the underlying monad by lifting its
-- 'getAcc' through 'StateT'.
instance MonadWriterAcc w acc m => MonadWriterAcc w acc (StateT s m) where
  getAcc = lift getAcc
-- Standalone-derived instances for 'StateT'.
-- 'MonadError' and 'MonadFuture' are derived from the same class on the
-- underlying monad @m@.
-- The algebraic classes ('Semigroup', 'Monoid', 'Semiring', 'Ring') each
-- require the same class on @m (a,s,Void)@.
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Semiring (m (a,s,Void)) => Semiring (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)
deriving instance (Monad m,MonadFuture n m) => MonadFuture n (StateT s m)

-- | Isomorphism between 'StateT' and the 'RWST' value it wraps.
_StateT :: Iso (StateT s m a) (StateT t n b) (RWST Void Void s m a) (RWST Void Void t n b)
_StateT = iso StateT unwrap
  where
    -- The pattern is kept lazy, exactly as in the inline lambda form.
    unwrap ~(StateT rws) = rws
-- | Isomorphism between a 'StateT' computation and a state-passing function
-- of shape @s -> m (s,a)@.
--
-- The inner 'iso' converts between @(s,a)@ pairs and @(a,s,_)@ triples: one
-- direction fills the third slot with 'zero', the other discards it. The
-- rest of the chain goes through '_RWST' and '_StateT'.
stateT :: (Functor m,Functor n) => Iso (StateT s m a) (StateT t n b) (s -> m (s,a)) (t -> n (t,b))
stateT = mapping (mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a)))
          .promapping i'_._RWST._StateT
-- | Keep only the second component of a pair inside any functor.
--
-- For the @s -> (s,a)@ functions exposed through 'state', this drops the
-- state and keeps the result. The signature is generalised from the function
-- functor to any 'Functor' so that it mirrors 'exec'; the previous type
-- @(s -> (a, b)) -> (s -> b)@ is an instance of this one, so existing callers
-- are unaffected.
eval :: Functor f => f (a, b) -> f b
eval = map snd
-- | Keep only the first component of a pair inside any functor.
--
-- For the @s -> (s,a)@ functions exposed through 'state', this keeps the
-- state and drops the result.
exec :: Functor f => f (a, b) -> f a
exec pairs = map fst pairs
-- | Isomorphism between a 'State' computation and a pure state-passing
-- function @s -> (s,a)@; built from 'stateT' by mapping 'i'Id' over the
-- function's result.
state :: Iso (State s a) (State t b) (s -> (s,a)) (t -> (t,b))
state = mapping i'Id.stateT

-- | Assign a value to the part(s) of the current state focused by a
-- traversal: the state is replaced, via 'modify', by @set target value@.
(=-) :: MonadState s m => Traversal' s s' -> s' -> m ()
infixl 0 =-,=~
(=-) target value = modify (set target value)
-- | Update the part(s) of the current state focused by a traversal with a
-- function: the state is replaced, via 'modify', by @warp target f@.
(=~) :: MonadState s m => Traversal' s a -> (a -> a) -> m ()
(=~) target f = modify (warp target f)
-- | Fetch the current state and hand it, together with the optic and the
-- action, to 'forl_'.
(^>=) :: MonadState s m => LensLike m a a s s -> (a -> m ()) -> m ()
(^>=) l k = get >>= \current -> forl_ l current k
-- | Read the current state and return a projection of it.
gets :: MonadState s m => (s -> a) -> m a
gets project = get <&> project
-- | Read the current state and view it through a getter.
getl :: MonadState s m => Getter' s a -> m a
getl getter = by getter <$> get

-- | Run an action, then write back the value the lens focused on before the
-- action ran. The action's own result is returned.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = getl l >>= restoring
  where
    -- Runs the action first, then writes the saved value back through the lens.
    restoring saved = st <* (l =- saved)

-- * The State Arrow
-- | 'StateT' with its parameters reordered to @m s a@, so that @StateA m@
-- can be given the 'Category', 'Split' and 'Choice' instances below, with
-- the state type @s@ as source and the result type @a@ as target.
newtype StateA m s a = StateA (StateT s m a)
-- | Isomorphism between 'StateA' and the 'StateT' value it wraps.
stateA :: Iso (StateA m s a) (StateA m' s' a') (StateT s m a) (StateT s' m' a')
stateA = iso StateA unwrap
  where
    unwrap (StateA st) = st
-- | Identity is 'get'. Composition runs the first arrow, feeds its result to
-- the second arrow as that arrow's input state, and pairs the first arrow's
-- updated state with the second arrow's result (the second arrow's own
-- updated state is dropped).
instance Monad m => Category (StateA m) where
  id = StateA get
  StateA sbc . StateA sab = StateA (chained ^. stateT)
    where
      chained a = (sab^..stateT) a >>= \(s',b) -> (s',).snd <$> (sbc^..stateT) b
-- | Combines two arrows by viewing each as a 'Kleisli' arrow, joining them
-- with the 'Kleisli' '<#>', and regrouping the result from
-- @((a',c),(b',d))@ into @((a',b'),(c,d))@: updated states together, results
-- together.
instance Monad m => Split (StateA m) where
  StateA sac <#> StateA sbd = StateA (map2 regroup both ^. stateT)
    where
      both = (Kleisli (sac^..stateT) <#> Kleisli (sbd^..stateT)) ^.. i'Kleisli
      regroup ((a',c),(b',d)) = ((a',b'),(c,d))
-- | Combines two arrows with the same result type using '<|>' on their
-- state-passing functions. Beforehand, the updated state of the first is
-- wrapped in 'Left' and that of the second in 'Right'.
instance Monad m => Choice (StateA m) where
  StateA sac <|> StateA sbc = StateA (merged ^. stateT)
    where
      -- Apply an injection to the state component of the arrow's output.
      tagState inj = map2 (first inj)
      merged = tagState Left (sac^..stateT) <|> tagState Right (sbc^..stateT)

-- | Traverse a structure while threading a state through it. Each element's
-- step @f a@ is viewed as a 'State' action through 'state', the actions are
-- sequenced with 'traverse', and the result is viewed back as a function from
-- the initial state to the final state and the rebuilt structure.
mapAccum :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccum f t = traverse (\a -> by state (f a)) t ^.. state
-- | 'mapAccum' with the final state discarded.
mapAccum_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccum_ f t s = snd (mapAccum f t s)
-- | Same construction as 'mapAccum', but the 'State' actions are additionally
-- wrapped through 'i'Backwards' (presumably so that the state is threaded
-- from the last element to the first — confirm against 'i'Backwards').
mapAccumR :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccumR f t = traverse (\a -> by (state.i'Backwards) (f a)) t ^.. state.i'Backwards
-- | 'mapAccumR' with the final state discarded.
mapAccumR_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccumR_ f t s = snd (mapAccumR f t s)

-- | Shift values through a structure in 'mapAccum_' order: every position
-- receives the value carried in the state (initially the extra argument), and
-- the element that was there becomes the new carried value.
push :: Traversable t => t a -> a -> t a
push t seed = mapAccum_ (\current carried -> (current, carried)) t seed
-- | Shift values through a structure in 'mapAccumR_' order: every position
-- receives the value carried in the state (initially the extra argument), and
-- the element that was there becomes the new carried value.
pop :: Traversable t => t a -> a -> t a
pop t seed = mapAccumR_ (\current carried -> (current, carried)) t seed

-- | Plain alias for @a@; labels the second component of each pair in the
-- signature of 'withNext'.
type Next a = a
-- | Plain alias for @a@; labels the first component of each pair in the
-- signature of 'withPrev'.
type Prev a = a
-- | Pair every element with the value carried in 'mapAccum_' order: the
-- carried value starts as the first argument and is replaced by each element
-- once it has been visited.
withPrev :: Traversable t => a -> t a -> t (Prev a,a)
withPrev seed t = mapAccum_ (\current prev -> (current, (prev, current))) t seed
-- | Pair every element with the value carried in 'mapAccumR_' order: the
-- carried value starts as the second argument and is replaced by each element
-- once it has been visited.
withNext :: Traversable t => t a -> a -> t (a,Next a)
withNext t seed = mapAccumR_ (\current next -> (current, (current, next))) t seed