{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# EXT InlineAll #-}
module Control.Monad.State.Layered where
import Prelude hiding (Monad)
import Control.Applicative
import Control.Lens.Utils hiding (Getter, Setter, Setter')
import Control.Monad.Branch
import Control.Monad.Catch
import Control.Monad.Fail
import Control.Monad.Identity hiding (Monad)
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Constraint
import Data.Default
import Data.Kind (Type)
import Type.Bool
import qualified Prelude as P
import qualified Control.Monad.State.Strict as S
-- | Short alias for Prelude's 'P.Monad' class. Prelude's @Monad@ is hidden in
-- this module because the module defines its own 'Monad' constraint synonym.
type M = P.Monad
-- | A pure layered state computation: 'StateT' over 'Identity'.
type State s = StateT s Identity
-- | State monad transformer: a newtype around the transformer from
-- "Control.Monad.State.Strict". This module attaches its own layered
-- 'Getter' / 'Setter' instances to it. The listed classes are derived;
-- presumably via newtype deriving, depending on the project's default
-- extensions (TODO confirm).
newtype StateT s m a = StateT (S.StateT s m a)
deriving ( Applicative, Alternative, Functor, M, MonadFail, MonadFix
, MonadIO, MonadPlus, MonadTrans, MonadThrow, MonadBranch )
-- Generates the lens 'Wrapped' instance relied on by 'wrap', 'unwrap' and
-- '_Wrapped' throughout this module.
makeWrapped ''StateT
-- | A pure stack of 'StateT' layers, one layer per element of @ss@.
type States ss = StatesT ss Identity

-- | Build a stack of 'StateT' layers on top of @m@. The head of the
-- type-level list becomes the outermost layer, e.g.
--
-- > StatesT '[A, B] m  ~  StateT A (StateT B m)
type family StatesT ss m where
  StatesT '[]           base = base
  StatesT (top ': rest) base = StateT top (StatesT rest base)
-- | The state type addressed by label @l@ inside monad @m@. A label of kind
-- 'Type' is taken to be the state type itself. A label of any other kind is
-- resolved by searching the transformer stack with 'StateData'.
type InferStateData (l :: k) m = InferStateData' k l m
-- | Worker for 'InferStateData'; it dispatches on the kind @k@ of the label.
type family InferStateData' k (l :: k) m where
InferStateData' Type l m = l
InferStateData' k l m = StateData l m
-- | Constraint: lifting through transformer @t@ does not change which state
-- type label @l@ resolves to.
type TransStateData l t m = (InferStateData l (t m) ~ InferStateData l m)
-- | Walk the transformer stack from the outside in. The result is the state
-- type of the first 'StateT' layer whose state matches label @l@ (according
-- to 'MatchedBases'). Other applied transformer layers are skipped.
-- Equations exist only for 'StateT' and other applied transformers.
type family StateData l m where
StateData l (StateT s m) = If (MatchedBases l s) s (StateData l m)
StateData l (t m) = StateData l m
-- | Does label @a@ match type @b@?
--   * Same kind: the answer is the type-level equality @a == b@.
--   * Different kinds: peel the last type argument off @b@ and retry. For
--     example, a label @Foo@ matches @Foo Int@.
--   * Otherwise: 'False.
type family MatchedBases (a :: ka) (b :: kb) :: Bool where
MatchedBases (a :: k) (b :: k) = a == b
MatchedBases (a :: k) (b t :: l) = MatchedBases a b
MatchedBases (a :: k) (b :: l) = 'False
-- | Read and write access to the state layer selected by label @l@.
type Monad l m = (Getter l m, Setter l m)
-- | Read the state selected by label @l@. The label is supplied with a type
-- application, e.g. @get \@Int@.
class M m => Getter l m where get :: m (InferStateData l m)
-- | Overwrite the state selected by label @l@.
class M m => Setter l m where put :: InferStateData l m -> m ()
-- Instances for 'StateT' layers (the most specific matching instance wins):
--  * A 'Type'-kinded label equal to this layer's state type accesses this
--    layer directly.
--  * A 'Type'-kinded label that differs is lifted into the inner monad.
--  * A label of another kind is decided by 'MatchedBases' through the
--    'Getter__' / 'Setter__' helper classes.
-- Any other transformer @t@ is lifted through, provided 'TransStateData'
-- holds, i.e. @t@ does not change what the label resolves to.
instance M m => Getter (l :: Type) (StateT l m) where get = wrap S.get ; {-# INLINE get #-}
instance M m => Setter (l :: Type) (StateT l m) where put a = wrap $ S.put a ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-} Getter l m => Getter (l :: Type) (StateT s m) where get = lift $ get @l ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-} Setter l m => Setter (l :: Type) (StateT s m) where put a = lift $ put @l a ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-} (M m, Getter__ ok l (StateT s m), ok ~ MatchedBases l s) => Getter (l :: k) (StateT s m) where get = get__ @ok @l ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-} (M m, Setter__ ok l (StateT s m), ok ~ MatchedBases l s) => Setter (l :: k) (StateT s m) where put a = put__ @ok @l a ; {-# INLINE put #-}
instance {-# OVERLAPPABLE #-} (M (t m), MonadTrans t, Getter l m, TransStateData l t m) => Getter (l :: k) (t m) where get = lift $ get @l ; {-# INLINE get #-}
instance {-# OVERLAPPABLE #-} (M (t m), MonadTrans t, Setter l m, TransStateData l t m) => Setter (l :: k) (t m) where put a = lift $ put @l a ; {-# INLINE put #-}
-- | Helper classes for labels that are not of kind 'Type'. The @ok@ index
-- records whether the label matched the current 'StateT' layer's state type.
class M m => Getter__ (ok :: Bool) l m where get__ :: m (InferStateData l m)
class M m => Setter__ (ok :: Bool) l m where put__ :: InferStateData l m -> m ()
-- Matched: re-dispatch with the layer's own state type @s@ as a
-- 'Type'-kinded label.
instance (M m, InferStateData l (StateT s m) ~ s) => Getter__ 'True l (StateT s m) where get__ = get @s ; {-# INLINE get__ #-}
instance (M m, InferStateData l (StateT s m) ~ s) => Setter__ 'True l (StateT s m) where put__ a = put @s a ; {-# INLINE put__ #-}
-- Not matched: keep searching in the inner monad.
instance (Getter l m, TransStateData l (StateT s) m) => Getter__ 'False l (StateT s m) where get__ = lift $ get @l ; {-# INLINE get__ #-}
instance (Setter l m, TransStateData l (StateT s) m) => Setter__ 'False l (StateT s m) where put__ a = lift $ put @l a ; {-# INLINE put__ #-}
-- | Read and write access to every state layer labelled in @ss@.
type MonadStates ss m = (Getters ss m, Setters ss m)

-- | One 'Getter' constraint per label in @ss@.
type Getters ss m = Monads__ Getter ss m

-- | One 'Setter' constraint per label in @ss@.
type Setters ss m = Monads__ Setter ss m

-- | Apply the two-parameter class @p@ to each label in the list (with the
-- same monad), conjoining the resulting constraints.
type family Monads__ p ss m :: Constraint where
  Monads__ cls (lbl ': lbls) mon = (cls lbl mon, Monads__ cls lbls mon)
  Monads__ cls '[]           mon = ()
-- | Read the state selected by label @l@ and project a part of it through
-- the given lens.
gets :: ∀ l m s a. (Getter l m, s ~ InferStateData l m)
     => Lens' s a -> m a
gets lns = fmap (view lns) (get @l)
{-# INLINE gets #-}
-- | State type of the outermost 'StateT' layer of a transformer stack. Any
-- non-'StateT' layers above it are skipped. Equation order matters here,
-- because the second equation also matches 'StateT'.
type family TopStateData m where
  TopStateData (StateT st inner) = st
  TopStateData (trans inner)     = TopStateData inner
-- | Read and write access to the outermost 'StateT' layer's state.
type Monad' m = (Getter' m, Setter' m)
-- | Read access to the outermost 'StateT' layer's state.
type Getter' m = Getter (TopStateData m) m
-- | Write access to the outermost 'StateT' layer's state.
type Setter' m = Setter (TopStateData m) m

-- | Read the state of the outermost 'StateT' layer. No label is needed.
get' :: ∀ m. Getter' m => m (TopStateData m)
get' = get @(TopStateData m)
{-# INLINE get' #-}

-- | Overwrite the state of the outermost 'StateT' layer.
put' :: ∀ m. Setter' m => TopStateData m -> m ()
put' newState = put @(TopStateData m) newState
{-# INLINE put' #-}

-- | Read the outermost state and project a part of it through a lens.
gets' :: ∀ m s a. (Getter' m, s ~ TopStateData m) => Lens' s a -> m a
gets' lns = fmap (view lns) get'
{-# INLINE gets' #-}
-- | Build a 'StateT' directly from a state-passing step function.
stateT :: (s -> m (a,s)) -> StateT s m a
stateT step = StateT (S.StateT step)
{-# INLINE stateT #-}
-- | Run a 'StateT' from an initial state, returning the result together
-- with the final state.
runT :: ∀ s m a. StateT s m a -> s -> m (a, s)
runT = S.runStateT . unwrap
{-# INLINE runT #-}

-- | Run a 'StateT' from an initial state and keep only the result.
-- Both arguments are forced to WHNF by the bang patterns. When the returned
-- value is demanded, both components of the final pair are forced, so a
-- bottom final state also makes the result bottom.
evalT :: ∀ s m a. S.Monad m => StateT s m a -> s -> m a
evalT !m !s = fmap takeResult (runT m s)
  where
    -- Named 'takeResult' rather than 'fst' so it does not shadow Prelude.
    takeResult (!a, !_) = a
{-# INLINE evalT #-}

-- | Run a 'StateT' from an initial state and keep only the final state.
-- Both arguments are forced to WHNF by the bang patterns. When the returned
-- state is demanded, both components of the final pair are forced, so a
-- bottom result value also makes the state bottom.
execT :: ∀ s m a. S.Monad m => StateT s m a -> s -> m s
execT !m !s = fmap takeState (runT m s)
  where
    -- Named 'takeState' rather than 'snd' so it does not shadow Prelude.
    takeState (!_, !st) = st
{-# INLINE execT #-}
-- | Like 'runT', but starts from the 'def' state.
runDefT :: ∀ s m a. Default s => StateT s m a -> m (a, s)
runDefT !act = runT act def
{-# INLINE runDefT #-}

-- | Like 'evalT', but starts from the 'def' state.
evalDefT :: ∀ s m a. (S.Monad m,Default s) => StateT s m a -> m a
evalDefT !act = evalT act def
{-# INLINE evalDefT #-}

-- | Like 'execT', but starts from the 'def' state.
execDefT :: ∀ s m a. (S.Monad m,Default s) => StateT s m a -> m s
execDefT !act = execT act def
{-# INLINE execDefT #-}
-- | Run a pure 'State' from an initial state, returning the result and the
-- final state.
run :: ∀ s a. State s a -> s -> (a, s)
run st = S.runState (unwrap st)
{-# INLINE run #-}

-- | Run a pure 'State' and keep only the result.
eval :: ∀ s a. State s a -> s -> a
eval st = S.evalState (unwrap st)
{-# INLINE eval #-}

-- | Run a pure 'State' and keep only the final state.
exec :: ∀ s a. State s a -> s -> s
exec st = S.execState (unwrap st)
{-# INLINE exec #-}
-- | Like 'run', but starts from the 'def' state.
runDef :: ∀ s a. Default s => State s a -> (a, s)
runDef st = run st def
{-# INLINE runDef #-}

-- | Like 'eval', but starts from the 'def' state.
evalDef :: ∀ s a. Default s => State s a -> a
evalDef st = eval st def
{-# INLINE evalDef #-}

-- | Like 'exec', but starts from the 'def' state.
execDef :: ∀ s a. Default s => State s a -> s
execDef st = exec st def
{-# INLINE execDef #-}
-- | Constraint: @m@ can read and write the state labelled @l@, and that
-- state has type @s@.
type InferMonadState s l m = (Monad l m, s ~ InferStateData l m)

-- | Effectful update of the state labelled @l@. The function receives the
-- current state and returns a result plus the new state to store. Both are
-- forced to WHNF before the new state is written.
modifyM :: ∀ l s m a. InferMonadState s l m => (s -> m (a, s)) -> m a
modifyM f = get @l >>= f >>= \(!res, !next) -> res <$ put @l next
{-# INLINE modifyM #-}

-- | Like 'modifyM', for an effectful update that produces no result.
modifyM_ :: ∀ l s m a. InferMonadState s l m => (s -> m s) -> m ()
modifyM_ f = modifyM @l (fmap (\next -> ((), next)) . f)
{-# INLINE modifyM_ #-}

-- | Like 'modifyM', for a pure update that also produces a result.
modify :: ∀ l s m a. InferMonadState s l m => (s -> (a, s)) -> m a
modify f = modifyM @l (return . f)
{-# INLINE modify #-}

-- | Apply a pure function to the state labelled @l@.
modify_ :: ∀ l s m a. InferMonadState s l m => (s -> s) -> m ()
modify_ f = modifyM_ @l (return . f)
{-# INLINE modify_ #-}
-- | Run an action, then put back the value the state labelled @l@ had
-- before the action started. The action's result is kept. No exception
-- handling is done: if the action short-circuits, the restore is skipped.
sub :: ∀ l s m a. InferMonadState s l m => m a -> m a
sub act = get @l >>= \saved -> act <* put @l saved
{-# INLINE sub #-}

-- | Run an action with the state labelled @l@ temporarily set to the given
-- value. The previous value is restored afterwards, as in 'sub'.
with :: ∀ l s m a. InferMonadState s l m => s -> m a -> m a
with new = withModified @l (const new)
{-# INLINE with #-}

-- | Run an action with the state labelled @l@ temporarily transformed by a
-- pure function.
withModified :: ∀ l s m a. InferMonadState s l m => (s -> s) -> m a -> m a
withModified f = withModifiedM @l (return . f)
{-# INLINE withModified #-}

-- | Run an action with the state labelled @l@ temporarily transformed by an
-- effectful function. The original state is restored afterwards, as in 'sub'.
withModifiedM :: ∀ l s m a. InferMonadState s l m => (s -> m s) -> m a -> m a
withModifiedM f act = sub @l (modifyM_ @l f >> act)
{-# INLINE withModifiedM #-}
-- | Constraint: @m@ can read and write its outermost 'StateT' layer, whose
-- state has type @s@.
type TopMonadState s m = (Monad' m, s ~ TopStateData m)

-- | Effectful update of the outermost state. The function returns a result
-- plus the new state. Both are forced to WHNF before the state is written.
modifyM' :: ∀ s m a. TopMonadState s m => (s -> m (a, s)) -> m a
modifyM' f = get' >>= f >>= \(!res, !next) -> res <$ put' next
{-# INLINE modifyM' #-}

-- | Like 'modifyM'', for an effectful update that produces no result.
modifyM'_ :: ∀ s m a. TopMonadState s m => (s -> m s) -> m ()
modifyM'_ f = modifyM' (fmap (\next -> ((), next)) . f)
{-# INLINE modifyM'_ #-}

-- | Like 'modifyM'', for a pure update that also produces a result.
modify' :: ∀ s m a. TopMonadState s m => (s -> (a, s)) -> m a
modify' f = modifyM' (return . f)
{-# INLINE modify' #-}

-- | Apply a pure function to the outermost state.
modify'_ :: ∀ s m a. TopMonadState s m => (s -> s) -> m ()
modify'_ f = modifyM'_ (return . f)
{-# INLINE modify'_ #-}
-- | Run an action, then put back the outermost state as it was before the
-- action started. The action's result is kept. No exception handling is
-- done: if the action short-circuits, the restore is skipped.
sub' :: ∀ s m a. TopMonadState s m => m a -> m a
sub' act = get' >>= \saved -> act <* put' saved
{-# INLINE sub' #-}

-- | Run an action with the outermost state temporarily set to the given
-- value. The previous value is restored afterwards, as in 'sub''.
with' :: ∀ s m a. TopMonadState s m => s -> m a -> m a
with' new = withModified' (const new)
{-# INLINE with' #-}

-- | Run an action with the outermost state temporarily transformed by a
-- pure function.
withModified' :: ∀ s m a. TopMonadState s m => (s -> s) -> m a -> m a
withModified' f = withModifiedM' (return . f)
{-# INLINE withModified' #-}

-- | Run an action with the outermost state temporarily transformed by an
-- effectful function. The original state is restored afterwards, as in 'sub''.
withModifiedM' :: ∀ s m a. TopMonadState s m => (s -> m s) -> m a -> m a
withModifiedM' f act = sub' (modifyM'_ f >> act)
{-# INLINE withModifiedM' #-}
-- | Transform the underlying monadic computation of a 'StateT' (result and
-- final state together), changing the inner monad and result type.
mapT :: (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapT f = wrap . S.mapStateT f . unwrap
{-# INLINE mapT #-}
-- | 'StateT' inherits 'PrimState' from its base monad, and 'primitive' is
-- simply lifted into that base monad.
instance PrimMonad m => PrimMonad (StateT s m) where
  type PrimState (StateT s m) = PrimState m
  primitive op = lift (primitive op)
  {-# INLINE primitive #-}