module Data.Profunctor.State where

import Data.Profunctor
import Control.Category (Category)
import qualified Control.Category as C
import Data.Bifunctor (first)
import Data.Profunctor.State.Class

-- | A state-threading wrapper around a profunctor @p@.
--
-- A value of type @StateT s p a b@ is an arrow from @a@ to @b@ that also
-- receives and emits a state of type @s@. Internally it is just @p@ from
-- the input/state pair @(a, s)@ to the output/state pair @(b, s)@.
newtype StateT s p a b
  = StateT (p (a, s) (b, s))

-- | Map pure functions over the visible value on the input and output
-- sides. The threaded state @s@ (the second component of each pair) is
-- passed through untouched, because only 'first' of each pair is mapped.
instance Profunctor p => Profunctor (StateT s p) where
  dimap f g (StateT s) = StateT (dimap (first f) (first g) s)

  -- Specialised one-sided maps. The class defaults
  -- (@lmap f = dimap f id@, @rmap = dimap id@) would also rebuild the
  -- untouched side's pair through @first id@. These delegate straight to
  -- @p@'s own 'lmap'/'rmap'. They agree with 'dimap' above:
  -- @lmap f . rmap g == dimap f g@ whenever @p@ is lawful.
  lmap f (StateT s) = StateT (lmap (first f) s)

  rmap g (StateT s) = StateT (rmap (first g) s)

-- | Composition of state-threading arrows, delegated to @p@'s own
-- 'Category' instance on the underlying @(x, s)@ pairs. The state emitted
-- by the first arrow is therefore the state received by the second.
instance (Category p) => Category (StateT s p) where
  -- Identity passes both the value and the state through unchanged.
  id = StateT C.id

  -- Unwrap both arrows, compose them in @p@, then rewrap. As with 'C..',
  -- the right-hand arrow @t@ comes first and feeds its @(b, s)@ into @s@.
  StateT s . StateT t = StateT (s C.. t)

-- | Embed an arrow that explicitly reads and returns the state.
--
-- The wrapped arrow sees the input paired with the current state,
-- @(a, s)@, as its value. The value it returns, @(b, s)@, carries the new
-- state: 'fst' keeps that returned pair and discards the separately
-- threaded copy of the old state.
--
-- NOTE(review): the body only needs 'Profunctor' ('dimap'). The
-- 'Category' constraint is presumably required by the 'ProfunctorState'
-- class itself. Confirm in "Data.Profunctor.State.Class" before
-- relaxing it.
instance (Category p, Profunctor p) => ProfunctorState s (StateT s p) where
  state (StateT p) = StateT (dimap exposeState fst p)
    where
      -- Hand the whole input/state pair to the wrapped arrow as its value,
      -- while still threading a copy of the state alongside. The
      -- as-pattern reuses the incoming pair instead of rebuilding it.
      exposeState input@(_, s) = (input, s)

-- | One-directional state access: @get'@ exposes the current state to
-- the wrapped arrow, and @put'@ lets the wrapped arrow install a new one.
instance (Profunctor p) => ProfunctorState' s (StateT s p) where
  -- @get'@: the wrapped arrow receives @(a, s)@ (input plus current
  -- state) as its value. A copy of the state is also threaded as usual.
  -- The as-pattern reuses the incoming pair instead of rebuilding it.
  get' (StateT p) = StateT (lmap (\input@(_, s) -> (input, s)) p)

  -- @put'@: the wrapped arrow emits @(b, s')@ as its value. 'fst' keeps
  -- that pair, so @s'@ becomes the new state and the threaded copy is
  -- dropped.
  put' (StateT p) = StateT (rmap fst p)