{-# LANGUAGE TemplateHaskell #-}
module Polysemy.State
  ( 
    State (..)
    
  , get
  , gets
  , put
  , modify
    
  , runState
  , evalState
  , runLazyState
  , evalLazyState
  , runStateIORef
    
  , hoistStateIntoStateT
  ) where
import qualified Control.Monad.Trans.State as S
import           Data.IORef
import           Data.Tuple (swap)
import           Polysemy
import           Polysemy.Internal
import           Polysemy.Internal.Combinators
import           Polysemy.Internal.Union
-- | An effect for reading and replacing a single value of type @s@.
data State s m a where
  -- | Request the state; the answer has the state's own type @s@.
  Get :: State s m s
  -- | Supply a new value of type @s@; the answer is unit.
  Put :: s -> State s m ()
-- Template Haskell splice over 'State'. The 'get' and 'put' actions used by
-- the rest of this module are presumably the ones it generates from the
-- constructors above — confirm against the 'makeSem' documentation.
makeSem ''State
-- | Fetch the state with 'get' and return the result of applying the given
-- projection to it.
gets :: forall s a r. Member (State s) r => (s -> a) -> Sem r a
gets project = project <$> get
{-# INLINABLE gets #-}
-- | Replace the state with the result of applying the given function to it:
-- a 'get' followed by a 'put' of the transformed value.
--
-- The new state is handed to 'put' unevaluated.
modify :: Member (State s) r => (s -> s) -> Sem r ()
modify update = get >>= \current -> put (update current)
{-# INLINABLE modify #-}
-- | Interpret the 'State' effect via 'stateful', starting from the given
-- initial state. The result pairs the final state with the computation's
-- value.
runState :: s -> Sem (State s ': r) a -> Sem r (s, a)
runState = stateful $ \case
  -- Reading leaves the state untouched and also returns it as the answer.
  Get       -> \current -> pure (current, current)
  -- Writing discards the incoming state in favour of the supplied one.
  Put fresh -> \_ -> pure (fresh, ())
{-# INLINE[3] runState #-}
-- | Like 'runState', but keeps only the computation's value and drops the
-- final state.
evalState :: s -> Sem (State s ': r) a -> Sem r a
-- Only the initial state is bound left of '=', so the INLINE unfolding still
-- applies to a partially applied @evalState s@.
evalState initial = \action -> snd <$> runState initial action
{-# INLINE evalState #-}
-- | Interpret the 'State' effect via 'lazilyStateful', starting from the given
-- initial state. The result pairs the final state with the computation's
-- value.
runLazyState :: s -> Sem (State s ': r) a -> Sem r (s, a)
runLazyState = lazilyStateful $ \case
  -- Reading leaves the state untouched and also returns it as the answer.
  Get       -> \current -> pure (current, current)
  -- Writing discards the incoming state in favour of the supplied one.
  Put fresh -> \_ -> pure (fresh, ())
{-# INLINE[3] runLazyState #-}
-- | Like 'runLazyState', but keeps only the computation's value and drops the
-- final state.
evalLazyState :: s -> Sem (State s ': r) a -> Sem r a
-- Only the initial state is bound left of '=', so the INLINE unfolding still
-- applies to a partially applied @evalLazyState s@.
evalLazyState initial = \action -> snd <$> runLazyState initial action
{-# INLINE evalLazyState #-}
-- | Interpret the 'State' effect by reading from and writing to the given
-- 'IORef'. Only the computation's value is returned; the current state stays
-- in the reference.
runStateIORef
    :: forall s r a
     . Member (Embed IO) r
    => IORef s
    -> Sem (State s ': r) a
    -> Sem r a
runStateIORef ref = interpret $ \case
  Get       -> embed (readIORef ref)
  Put fresh -> embed (writeIORef ref fresh)
{-# INLINE runStateIORef #-}
-- | Convert a 'Sem' computation whose head effect is 'State' into an
-- 'S.StateT' computation over the remaining effects @r@.
--
-- 'Get' and 'Put' requests are answered with 'S.get' and 'S.put'. Every other
-- request is re-sent into @'Sem' r@ with 'liftSem' after the state has been
-- threaded through it with 'weave'.
hoistStateIntoStateT
    :: Sem (State s ': r) a
    -> S.StateT s (Sem r) a
hoistStateIntoStateT (Sem m) = m $ \u ->
  -- 'decomp' separates the head effect from the rest: the 'Right' branches
  -- below match the 'State' constructors, 'Left' is any other request.
  case decomp u of
    Left x -> S.StateT $ \s ->
      -- The woven state is a @(state, value)@ pair, while 'S.runStateT' and
      -- the function wrapped by 'S.StateT' use @(value, state)@; hence the two
      -- 'swap's.
      liftSem . fmap swap
              . weave (s, ())
                      (\(s', m') -> fmap swap
                                  $ S.runStateT m' s')
                      -- NOTE(review): presumably the inspector 'weave' uses to
                      -- extract a value from the @(state, value)@ pair —
                      -- confirm against 'weave'.
                      (Just . snd)
              -- Apply this same conversion to the computations carried
              -- inside the request before weaving.
              $ hoist hoistStateIntoStateT x
    -- For 'State' requests, the answer from 'S.get' / 'S.put' replaces the
    -- contents of @z@ (via '<$') and the result is passed to @y@.
    Right (Weaving Get z _ y _)     -> fmap (y . (<$ z)) $ S.get
    Right (Weaving (Put s) z _ y _) -> fmap (y . (<$ z)) $ S.put s
{-# INLINE hoistStateIntoStateT #-}
-- Rewrite rule: 'runState' applied directly to a 'reinterpret' is replaced by
-- a single 'stateful' call whose handler runs 'runState' on each @f x@ with
-- the state current at that point.
-- NOTE(review): presumably why 'runState' is marked INLINE[3], so that this
-- rule can match before 'runState' is inlined — confirm.
{-# RULES "runState/reinterpret"
   forall s e (f :: forall m x. e m x -> Sem (State s ': r) x).
     runState s (reinterpret f e) = stateful (\x s' -> runState s' $ f x) s e
     #-}
-- Rewrite rule: 'runLazyState' applied directly to a 'reinterpret' is replaced
-- by a single 'lazilyStateful' call whose handler runs 'runLazyState' on each
-- @f x@ with the state current at that point.
-- NOTE(review): presumably why 'runLazyState' is marked INLINE[3], so that
-- this rule can match before 'runLazyState' is inlined — confirm.
{-# RULES "runLazyState/reinterpret"
   forall s e (f :: forall m x. e m x -> Sem (State s ': r) x).
     runLazyState s (reinterpret f e) = lazilyStateful (\x s' -> runLazyState s' $ f x) s e
     #-}