-- | The indexed state transformer: each @'StateT' _ i j@ term takes an input of type @i@ and gives an output of type @j@.

module Control.Monad.Indexed.Trans.State where

import Prelude hiding ((<*>), Monad (..))
import Control.Applicative (Alternative (..))
import qualified Control.Applicative as Base
import qualified Control.Monad as Base
import qualified Control.Monad.Fix as Base
import Control.Monad.Indexed.Signatures
import Data.Functor.Indexed

-- | An indexed state transformer over the functor @f@.
--
-- A value of type @'StateT' f i j a@ wraps a function that takes a state of
-- type @i@ and yields, inside @f@, a result of type @a@ paired with a state of
-- type @j@.
newtype StateT f i j a = StateT
    { runStateT :: i -> f (a, j)
      -- ^ Unwrap the transformer into its underlying transition function.
    }
  deriving (Functor)

-- | Run an action of the underlying functor without touching the state: the
-- incoming state is paired, unchanged, with every result of the action.
lift :: Functor f => f a -> StateT f k k a
lift xm = StateT $ \ k -> flip (,) k <$> xm

-- | Post-process the underlying computation of a 'StateT'.
--
-- The given function is applied to the result of running the transformer on
-- the incoming state, so it may change the functor, the result type and the
-- output state type; the input state type @i@ is unaffected.
mapStateT :: (f (a, j) -> g (b, k)) -> StateT f i j a -> StateT g i k b
mapStateT f (StateT x) = StateT (f . x)

-- | Replace the state by the result of applying a pure function to it.
--
-- Defined via 'modifyF', so the value returned is the state as it was /before/
-- the function was applied.
modify :: Applicative p => (i -> j) -> StateT p i j i
modify f = modifyF (pure . f)

-- | Replace the state by the result of an effectful function of it.
--
-- The function is run on the incoming state @i@; each state it produces becomes
-- the output state, and the original incoming state is returned as the value.
modifyF :: Functor f => (i -> f j) -> StateT f i j i
modifyF f = StateT $ \ i -> (,) i <$> f i

-- | Return the current state, leaving it unchanged.
get :: Applicative p => StateT p k k k
get = modifyF pure

-- | Set the state to the given value, discarding the incoming state (whose
-- type may differ from the new one), and return @()@.
put :: Applicative p => j -> StateT p i j ()
put j = StateT $ \ _ -> pure ((), j)

-- | Indexed application: the function computation runs first on the incoming
-- state, and the intermediate state it yields is fed to the argument
-- computation.
--
-- The body is a monad comprehension in the underlying monad @m@ (this syntax
-- needs the @MonadComprehensions@ extension).
instance Base.Monad m => Apply (StateT m) where
    StateT fm <*> StateT xm = StateT $ \ i ->
        [(f x, k) | (f, j) <- fm i, (x, k) <- xm j]

-- | Indexed bind: 'join' runs the outer computation, then runs the inner
-- computation it produced on the state the outer one left behind.
instance Base.Monad m => Bind (StateT m) where
    -- Kleisli-compose the outer transition with "run the resulting inner
    -- transformer on the resulting state".
    join (StateT xmm) = StateT $ xmm Base.>=> uncurry runStateT

-- | The ordinary 'Base.Applicative' instance, available when the input and
-- output state types coincide.
instance Base.Monad m => Base.Applicative (StateT m k k) where
    -- Pair the value with the untouched incoming state.
    pure a = StateT $ pure . (,) a
    -- The right-hand side is the indexed '<*>' (the Prelude one is hidden by
    -- the import list), specialised to equal state indices.
    (<*>) = (<*>)

-- | The ordinary 'Base.Monad' instance, available when the input and output
-- state types coincide.
instance Base.Monad m => Base.Monad (StateT m k k) where
    -- The right-hand side is the indexed '>>=' (the Prelude 'Monad' methods
    -- are hidden by the import list), specialised to equal state indices.
    (>>=) = (>>=)

-- | Choice is delegated to the underlying monad: both alternatives are given
-- the same incoming state.
instance Base.MonadPlus m => Alternative (StateT m k k) where
    -- Fails in the underlying monad whatever the incoming state is.
    empty = StateT $ \ _ -> empty
    -- Run both sides on the same state and combine them with the underlying
    -- '<|>'.
    StateT a <|> StateT b = StateT $ \ k -> a k <|> b k

-- | 'Base.MonadPlus' simply reuses the 'Alternative' instance.
instance Base.MonadPlus m => Base.MonadPlus (StateT m k k) where
    mzero = empty
    mplus = (<|>)

-- | Value recursion, delegated to the underlying monad's 'Base.mfix'.
instance Base.MonadFix m => Base.MonadFix (StateT m k k) where
    -- The fixed point is taken over the whole (value, state) pair, but only
    -- the value component is fed back into @f@; every iteration starts from
    -- the same incoming state @k@. 'fst' is applied lazily, so the pair is not
    -- forced before @f@ demands the value.
    mfix f = StateT $ \ k ->
        Base.mfix $ \ ak -> runStateT (f (fst ak)) k

-- | Lift a @callCC@-style operation of the underlying functors through
-- 'StateT'.
--
-- The escape continuation handed to @f@ ignores whatever state is current when
-- it is invoked and instead passes on the state @st@ that was current when
-- 'liftCallCC' itself was entered.
liftCallCC
 :: CallCC f g h (a, i) (b, j) (c, k) (d, l)
 -> CallCC (StateT f e j) (StateT g i k) (StateT h i l) a b c d
liftCallCC callCC f =
    StateT $ \ st ->
    callCC $ \ k ->
    -- The lifted continuation discards its own incoming state (@_@) and
    -- resumes the captured continuation with the saved state @st@.
    runStateT (f $ \ a -> StateT $ \ _ -> k (a, st)) st

-- | Lift a @catch@-style operation of the underlying functors through
-- 'StateT'.
--
-- The protected computation and the handler are both run on the same incoming
-- state: state changes made by the protected computation are not visible to
-- the handler.
liftCatch
 :: Catch e f g h (a, i) (b, j) (c, k)
 -> Catch e (StateT f l i) (StateT g l j) (StateT h l k) a b c
liftCatch catchE (StateT xm) h = StateT $ \ st ->
    catchE (xm st) (\ e -> runStateT (h e) st)

-- | Lift a @listen@-style operation of the underlying functor through
-- 'StateT'.
--
-- The argument @xm@ is applied to the incoming state and the result is handed
-- to the underlying @listen@; its output @((a, st'), w)@ is then re-associated
-- to @((a, w), st')@ so that the state is again the outer second component.
liftListen
 :: Functor f
 => Listen w f (a, j) b
 -> Listen w (StateT f i j) a (i -> b)
liftListen listen xm = StateT $ \ st ->
    -- The lazy pattern (~) avoids forcing the result tuple while reshuffling.
    flip fmap (listen (xm st)) $ \ ~((a, st'), w) -> ((a, w), st')

-- | Lift a @pass@-style operation of the underlying functors through 'StateT'.
--
-- The wrapped computation is run on the incoming state and its output
-- @((a, z), st')@ is re-associated to @((a, st'), z)@ -- moving the @z@
-- component outermost -- before being handed to the underlying @pass@.
liftPass
 :: Functor f
 => Pass z f g (a, k) (b, j)
 -> Pass z (StateT f i k) (StateT g i j) a b
liftPass pass (StateT xm) = StateT $ \ st ->
    -- The lazy pattern (~) avoids forcing the result tuple while reshuffling.
    pass $ flip fmap (xm st) $ \ ~((a, z), st') -> ((a, st'), z)