{-# OPTIONS -Wno-orphans #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_HADDOCK hide #-}

module Control.Functor.Linear.Internal.State
  ( StateT (..),
    State,
    state,
    get,
    put,
    gets,
    modify,
    replace,
    runStateT,
    runState,
    mapStateT,
    mapState,
    evalStateT,
    evalState,
    execStateT,
    execState,
    withStateT,
    withState,
  )
where

import Control.Functor.Linear.Internal.Class
import Control.Functor.Linear.Internal.Instances (Data (..))
import Control.Functor.Linear.Internal.MonadTrans
import qualified Control.Monad as NonLinear ()
import qualified Control.Monad.Trans.State.Strict as NonLinear
import Data.Functor.Identity
import qualified Data.Functor.Linear.Internal.Applicative as Data
import qualified Data.Functor.Linear.Internal.Functor as Data
import qualified Data.Tuple.Linear as Linear
import Data.Unrestricted.Linear.Internal.Consumable
import Data.Unrestricted.Linear.Internal.Dupable
import Prelude.Linear.Internal

-- # StateT
-------------------------------------------------------------------------------

-- | A (strict) linear state monad transformer.
--
-- A computation takes the current state linearly and produces, inside the
-- underlying monad @m@, its result paired with the next state.
-- The 'Data.Applicative' instance is borrowed through the 'Data' wrapper.
newtype StateT s m a = StateT (s %1 -> m (a, s))
  deriving (Data.Applicative) via (Data (StateT s m))

-- We derive Data.Applicative (via the Data wrapper) but not Data.Functor: a
-- hand-written Data.Functor instance can require a weaker constraint on m than
-- the Control.Functor instance does, whereas Data.Applicative needs a Monad
-- instance on m, just like Control.Applicative.

-- | The linear state monad: 'StateT' over the 'Identity' base monad.
type State s = StateT s Identity

-- | Return the current state.
--
-- The state is duplicated with 'dup': one copy becomes the result and the
-- other remains the state. This is why @s@ must be 'Dupable'.
get :: (Applicative m, Dupable s) => StateT s m s
get = state (\current -> dup current)

-- | Install a new state.
--
-- The displaced state is thrown away by 'Data.void', which is why @s@ must be
-- 'Consumable'.
put :: (Applicative m, Consumable s) => s %1 -> StateT s m ()
put new = Data.void (replace new)

-- | Observe the state through a projection.
--
-- The state is duplicated with 'dup'. One copy is fed to the projection,
-- whose output is the result; the other copy is kept as the state.
gets :: (Applicative m, Dupable s) => (s %1 -> a) %1 -> StateT s m a
gets project = state $ (\(seen, kept) -> (project seen, kept)) . dup

-- | Unwrap a 'StateT' into the state-passing function it holds.
runStateT :: StateT s m a %1 -> s %1 -> m (a, s)
runStateT (StateT run) = run

-- | Embed a pure, linear state transition in 'StateT'. The transition's
-- output is returned in the underlying monad with 'pure'.
state :: (Applicative m) => (s %1 -> (a, s)) %1 -> StateT s m a
state step = StateT (\s -> pure (step s))

-- | Run a 'State' computation from an initial state. Returns the result
-- together with the final state.
runState :: State s a %1 -> s %1 -> (a, s)
runState action initial = runIdentity' (runStateT action initial)

-- | Post-process the underlying action (result and final state together)
-- that a 'StateT' produces for each incoming state.
mapStateT :: (m (a, s) %1 -> n (b, s)) %1 -> StateT s m a %1 -> StateT s n b
mapStateT transform (StateT run) = StateT (\s -> transform (run s))

-- | Pre-process the incoming state before the computation receives it.
withStateT :: (s %1 -> s) %1 -> StateT s m a %1 -> StateT s m a
withStateT adjust (StateT run) = StateT (\s -> run (adjust s))

-- | Run a unit-valued 'StateT' computation and keep only the final state.
execStateT :: (Functor m) => StateT s m () %1 -> s %1 -> m s
execStateT action initial =
  fmap (\((), final) -> final) (runStateT action initial)

-- | Run a 'StateT' computation and keep only its result.
--
-- Use with care!
--   This consumes the final state, so might be costly at runtime.
evalStateT :: (Functor m, Consumable s) => StateT s m a %1 -> s %1 -> m a
evalStateT action initial = fmap Linear.fst (runStateT action initial)

-- | Post-process the result and final state of a 'State' computation.
-- This is 'mapStateT' with the 'Identity' layer unwrapped and rewrapped.
mapState :: ((a, s) %1 -> (b, s)) %1 -> State s a %1 -> State s b
mapState f = mapStateT (\wrapped -> Identity (f (runIdentity' wrapped)))

-- | 'withStateT' restricted to 'State': pre-process the incoming state.
withState :: (s %1 -> s) %1 -> State s a %1 -> State s a
withState adjust action = withStateT adjust action

-- | Run a unit-valued 'State' computation and keep only the final state.
execState :: State s () %1 -> s %1 -> s
execState action initial = runIdentity' (execStateT action initial)

-- | Run a 'State' computation and keep only its result.
--
-- Use with care!
--   This consumes the final state, so might be costly at runtime.
evalState :: (Consumable s) => State s a %1 -> s %1 -> a
evalState action initial = runIdentity' (evalStateT action initial)

-- | Replace the state with the result of applying a function to it.
modify :: (Applicative m) => (s %1 -> s) %1 -> StateT s m ()
modify update = state (\current -> ((), update current))

-- TODO: add strict version of `modify`

-- | @replace s@ installs @s@ as the new state and returns the state it
-- displaced.
replace :: (Applicative m) => s %1 -> StateT s m s
replace new = state (\old -> (old, new))

-- # Instances of StateT
-------------------------------------------------------------------------------

-- | Linear 'Functor' for the strict 'NonLinear.StateT' from transformers.
-- Neither the class nor the type is defined in this module, so this is an
-- orphan instance.
instance (Functor m) => Functor (NonLinear.StateT s m) where
  fmap f (NonLinear.StateT run) =
    NonLinear.StateT (\s0 -> fmap (\(a, s1) -> (f a, s1)) (run s0))

-- | Data-functor instance. It only needs @Data.Functor m@ and maps the
-- result while passing the state through unchanged.
instance (Data.Functor m) => Data.Functor (StateT s m) where
  fmap f (StateT run) = StateT $ \s0 ->
    Data.fmap (\(a, s1) -> (f a, s1)) $ run s0

-- | Control-functor instance. It maps the result with a linear function
-- while passing the state through unchanged.
instance (Functor m) => Functor (StateT s m) where
  fmap f (StateT run) = StateT $ \s0 ->
    fmap (\(a, s1) -> (f a, s1)) $ run s0

-- | 'pure' leaves the state alone. '<*>' threads the state left to right:
-- the function-producing computation runs first, then the argument-producing
-- computation runs on the state it left behind.
instance (Monad m) => Applicative (StateT s m) where
  pure a = StateT $ \s -> return (a, s)
  StateT runF <*> StateT runX = StateT $ \s0 ->
    runF s0 >>= \(g, s1) ->
      runX s1 >>= \(x, s2) ->
        return (g x, s2)

-- | Bind runs the first computation, then runs the continuation chosen by its
-- result on the state it produced.
instance (Monad m) => Monad (StateT s m) where
  StateT run >>= k = StateT $ \s0 ->
    run s0 >>= \(x, s1) ->
      runStateT (k x) s1

-- | Lifting runs the inner action and pairs its result with the unchanged
-- incoming state.
instance MonadTrans (StateT s) where
  lift action = StateT $ \s -> fmap (\a -> (a, s)) action