module Control.Arrow.State (module Control.Arrow.State.Class, StateT (..)) where

import Prelude hiding ((.), id);

import Control.Arrow;
import Control.Arrow.Transformer;
import Control.Arrow.State.Class;
import Control.Category;
import Util;

newtype StateT s r a b = StateT { runStateT :: r (a, s) (b, s) };

instance ArrowTransformer (StateT s) where {
  lift = StateT . (*** id);
  tmap f = StateT . f . runStateT;
};

instance (Category r) => Category (StateT s r) where {
  id = StateT id;
  StateT x . StateT y = StateT (x . y);
};

instance (Arrow r) => Arrow (StateT s r) where {
  arr = StateT . arr . (*** id);
  first = StateT . (>>>) swap_snds_A . (<<<) swap_snds_A . (*** id) . runStateT;
};

instance (ArrowApply r) => ArrowApply (StateT s r) where {
  app = StateT $ arr (\ ((StateT f, x), s) -> (f, (x, s))) >>> app;
};

instance (Arrow r) => ArrowState s (StateT s r) where {
  get = StateT $ arr $ \ (_, s) -> (s,  s);
  put = StateT $ arr $ \ (s, _) -> ((), s);
};

instance (Arrow r, ArrowTransformer xT, Arrow (xT (StateT s r))) => ArrowState s (xT (StateT s r)) where {
  get = lift get;
  put = lift put;
};