{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Trustworthy #-}
module Control.Eff.State.Lazy where
import Control.Eff.Internal
import Control.Eff.Writer.Lazy
import Control.Eff.Reader.Lazy
import Data.OpenUnion
import Control.Monad.Base
import Control.Monad.Trans.Control
-- | The state effect over a state of type @s@.  Each request is indexed by
-- @v@, the type of the value the handler answers it with.
data State s v where
  -- | Read the current state; answered with the state itself.
  Get :: State s s
  -- | Replace the current state with the given value; answered with @()@.
  Put :: s -> State s ()
-- | Lift control operations from the base monad through a 'State' layer.
--
-- 'liftBaseWith' snapshots the current state with 'get' and hands the base
-- monad a runner that executes the stateful computation via 'runState'
-- from that snapshot.  'restoreM' recovers the (result, state) pair from
-- the underlying layer and reinstates the state with 'put'.
instance ( MonadBase m m
         , SetMember Lift (Lift m) r
         , MonadBaseControl m (Eff r)
         ) => MonadBaseControl m (Eff (State s ': r)) where
    -- This layer's saved state is the lower layer's saved state wrapped
    -- around a (result, final state) pair.
    type StM (Eff (State s ': r)) a = StM (Eff r) (a,s)

    liftBaseWith f =
      get >>= \s0 ->
        raise (liftBaseWith (\runInBase -> f (runInBase . runState s0)))

    restoreM st =
      raise (restoreM st) >>= \(a, s1 :: s) -> put s1 >> return a
-- Kept NOINLINE, presumably so that the "get/bind" rule below can still see
-- 'get' at call sites (GHC may not fire a rule whose LHS head was inlined).
{-# NOINLINE get #-}
-- | Return the current state of type @s@ by sending a 'Get' request.
get :: Member (State s) r => Eff r s
get = send Get
-- Rewrite @get >>= k@ to bind directly on the underlying 'send' request.
{-# RULES
  "get/bind" forall k. get >>= k = send Get >>= k
 #-}
-- Kept NOINLINE, presumably so that the "put/bind" and "put/semibind" rules
-- below can still see 'put' at call sites.
{-# NOINLINE put #-}
-- | Replace the current state with the given value by sending a 'Put'
-- request.
put :: Member (State s) r => s -> Eff r ()
put s = send (Put s)
-- Rewrite @put v >>= k@ to bind directly on the underlying 'send' request.
{-# RULES
  "put/bind"     forall k v. put v >>= k = send (Put v) >>= k
 #-}
-- Same rewrite for @put v >> k@, expressed as a bind that ignores the @()@.
{-# RULES
  "put/semibind" forall k v. put v >>  k = send (Put v) >>= (\() -> k)
 #-}
-- | Handle the 'State' effect using the generic 'handle_relay_s' combinator,
-- threading the state through each request.  'Get' resumes with the current
-- state; 'Put' resumes with @()@ and installs the new state.  The final
-- result is paired with the final state.  This is intended to match
-- 'runState', which uses explicit recursion instead.
runState' :: s -> Eff (State s ': r) a -> Eff r (a, s)
runState' initial =
  handle_relay_s initial
    (\st x -> return (x, st))
    (\st req k -> case req of
       Put st' -> k st' ()
       Get     -> k st st)
-- | Handle the 'State' effect by explicit recursion over the computation.
-- 'Get' resumes with the current state.  'Put' resumes with @()@ and
-- continues under the new state.  Any other request is re-emitted with a
-- continuation that keeps handling 'State' under the state current at that
-- point.
runState :: s                     -- ^ Initial state
         -> Eff (State s ': r) a  -- ^ Computation using the 'State' effect
         -> Eff r (a, s)          -- ^ Final result paired with final state
runState st0 comp = case comp of
  Val x -> return (x, st0)
  E u q -> case decomp u of
    Right Get       -> runState st0 (q ^$ st0)
    Right (Put st1) -> runState st1 (q ^$ ())
    Left other      -> E other (singleK (runState st0 . (q ^$)))
-- | Replace the state with the result of applying the given function to it.
-- The new state is not forced here.
modify :: (Member (State s) r) => (s -> s) -> Eff r ()
modify f = do
  current <- get
  put (f current)
-- | Like 'runState', but keep only the final result and drop the final state.
evalState :: s -> Eff (State s ': r) a -> Eff r a
evalState s m = fst <$> runState s m
-- | Like 'runState', but keep only the final state and drop the result.
execState :: s -> Eff (State s ': r) a -> Eff r s
execState s m = snd <$> runState s m
-- | A data-free tag whose only role is to fix the state type @s@ that
-- 'transactionState' operates on (its value is ignored there).
data TxState s = TxState
-- | Run a computation with transactional 'State' semantics.
--
-- The state of type @s@ (selected by the 'TxState' tag) is read once up
-- front.  After that, every 'Get'/'Put' issued by @m@ is served from a
-- private copy rather than forwarded.  The enclosing state is overwritten
-- with the private copy only when @m@ reaches its final value.  If @m@ never
-- gets there (for instance, another handler abandons the continuation), the
-- enclosing state keeps its old value.  Requests for other effects are
-- forwarded unchanged.
transactionState :: forall s r a. Member (State s) r =>
                    TxState s -> Eff r a -> Eff r a
transactionState _ m = get >>= \s0 -> go s0 m
 where
   go :: s -> Eff r a -> Eff r a
   go cur (Val x) = put cur >> return x
   go cur (E (u :: Union r b) q) =
     case (prj u :: Maybe (State s b)) of
       Just Get        -> go cur (q ^$ cur)
       Just (Put cur') -> go cur' (q ^$ ())
       Nothing         -> E u (qComps q (go cur))
-- | Handle a 'Writer' and a 'Reader' over the same type @s@ together, treating
-- them as one piece of state.  'Tell' replaces the state and resumes with
-- @()@.  'Ask' resumes with the current state.  Other requests are re-emitted
-- with a continuation that keeps the current state.  The final result is
-- paired with the final state.
runStateR :: s -> Eff (Writer s ': Reader s ': r) a -> Eff r (a, s)
runStateR = go
 where
   go :: s -> Eff (Writer s ': Reader s ': r) a -> Eff r (a, s)
   go cur (Val x) = return (x, cur)
   go cur (E u q) = case decomp u of
       Right (Tell w) -> resume w ()
       Left others    -> case decomp others of
         Right Ask    -> resume cur cur
         Left others' -> E others' (singleK (resume cur))
     where
       -- Resume the suspended computation, continuing under state @next@.
       resume next = qComp q (go next)