module Control.Monad.Indexed.Trans.State where import Control.Applicative import Control.Monad ((>=>), MonadPlus (..)) import Control.Monad.Fix (MonadFix (..)) import Data.Functor.Indexed newtype StateT f i j a = StateT { runStateT :: i -> f (a, j) } deriving (Functor) mapStateT :: (f (a, j) -> g (b, k)) -> StateT f i j a -> StateT g i k b mapStateT f (StateT x) = StateT (f . x) modify :: Applicative p => (i -> j) -> StateT p i j i modify f = modifyF (pure . f) modifyF :: Functor f => (i -> f j) -> StateT f i j i modifyF = StateT . liftA2 fmap (,) get :: Applicative p => StateT p k k k get = modifyF pure put :: Applicative p => j -> StateT p i j () put = StateT . pure . pure . (,) () instance Monad m => IxApplicative (StateT m) where ipure a = StateT $ pure . (,) a StateT fm `iap` StateT xm = StateT $ \ i -> [(f x, k) | (f, j) <- fm i, (x, k) <- xm j] instance Monad m => IxMonad (StateT m) where ijoin = StateT . (>=> uncurry runStateT) . runStateT instance Monad m => Applicative (StateT m k k) where pure = ipure (<*>) = iap instance Monad m => Monad (StateT m k k) where (>>=) = flip ibind instance MonadPlus m => Alternative (StateT m k k) where empty = StateT (pure empty) StateT a <|> StateT b = StateT (liftA2 (<|>) a b) instance MonadPlus m => MonadPlus (StateT m k k) where mzero = empty mplus = (<|>) instance MonadFix m => MonadFix (StateT m k k) where mfix f = StateT $ mfix . \ k -> flip runStateT k . f . fst