module Control.Monad.Ology.Data.Ref where

import Control.Monad.Ology.Data.Param
import Control.Monad.Ology.Data.Prod
import Control.Monad.Ology.General
import Control.Monad.Ology.Specific.StateT
import qualified Control.Monad.ST.Lazy as Lazy
import qualified Control.Monad.ST.Strict as Strict
import Data.IORef
import qualified Data.STRef.Lazy as Lazy
import qualified Data.STRef.Strict as Strict
import Import

-- | A reference of a monad (as in 'StateT'): a value of type @a@ that can be
-- read and written by actions in @m@. It is represented simply as a pair of
-- a getter action and a setter function.
data Ref m a = MkRef
    { refGet :: m a
    -- ^ Action that reads the current value of the reference.
    , refPut :: a -> m ()
    -- ^ Action that writes a new value to the reference.
    }

-- | A reference can be re-typed through a pair of conversions: the getter's
-- result is mapped forwards with the first function, and values passed to
-- the setter are mapped backwards with the second before being written.
instance Functor m => Invariant (Ref m) where
    invmap f g (MkRef gt pt) = MkRef (fmap f gt) (pt . g)

-- | References can be paired up.
--
-- * 'rUnit' is a trivial reference to @()@: reading yields @()@ and writing
--   does nothing.
-- * '<***>' combines two references into a reference to a pair: reading reads
--   the left reference then the right one; writing writes the left component
--   then the right one.
instance Applicative m => Productable (Ref m) where
    rUnit = MkRef (pure ()) (\_ -> pure ())
    ra <***> rb =
        MkRef (liftA2 (,) (refGet ra) (refGet rb)) $ \(a, b) ->
            refPut ra a *> refPut rb b

-- | Apply a pure function to the value held in a reference.
--
-- The value is read with 'refGet' and the result written back with 'refPut'
-- as two separate actions.
refModify :: Monad m => Ref m a -> (a -> a) -> m ()
refModify ref f = do
    a <- refGet ref
    refPut ref $ f a

-- | Apply a monadic function to the value held in a reference.
--
-- The value is read with 'refGet', passed to the given function, and the
-- function's result is then written back with 'refPut'.
refModifyM :: Monad m => Ref m a -> (a -> m a) -> m ()
refModifyM ref f = do
    a <- refGet ref
    a' <- f a
    refPut ref a'

-- | Restore the original value of this reference after the operation.
--
-- The current value is captured with 'refGet' before the operation runs and
-- handed to 'refPut' as the release step of 'bracketNoMask'.
-- NOTE(review): presumably the restore also happens when the operation
-- throws; confirm against the definition of 'bracketNoMask'.
refRestore :: MonadException m => Ref m a -> m --> m
refRestore ref mr = bracketNoMask (refGet ref) (refPut ref) $ \_ -> mr

-- | Focus a reference on a part of its value using a lens.
--
-- Reading reads the whole value and views it through the lens. Writing reads
-- the whole value, replaces the focused part with the new value, and writes
-- the whole value back.
lensMapRef ::
       forall m a b. Monad m
    => Lens' a b
    -> Ref m a
    -> Ref m b
lensMapRef l ref = let
    -- View through the lens by instantiating its functor to 'Const'.
    refGet' = fmap (\a -> getConst $ l Const a) $ refGet ref
    -- Set through the lens by instantiating its functor to 'Identity',
    -- ignoring the old focused part.
    refPut' b = do
        a <- refGet ref
        refPut ref $ runIdentity $ l (\_ -> Identity b) a
    in MkRef refGet' refPut'

-- | Lift a reference in a monad to a reference in a transformed monad, by
-- applying 'lift' to both its getter and its setter.
liftRef :: (MonadTrans t, Monad m) => Ref m --> Ref (t m)
liftRef (MkRef g m) = MkRef (lift g) $ \a -> lift $ m a

-- | The state of a 'StateT' monad as a reference, built from 'get' and 'put'.
stateRef :: Monad m => Ref (StateT s m) s
stateRef = MkRef get put

-- | Run a state monad over this reference.
--
-- The reference's current value is used as the initial state, and the final
-- state is written back to the reference once the state computation has
-- produced its result.
refRunState :: Monad m => Ref m s -> StateT s m --> m
refRunState ref sm = do
    olds <- refGet ref
    (a, news) <- runStateT sm olds
    refPut ref news
    return a

-- | View an 'IORef' as a reference in 'IO', using 'readIORef' and
-- 'writeIORef'.
ioRef :: IORef a -> Ref IO a
ioRef r = MkRef (readIORef r) (writeIORef r)

-- | View a strict-'Strict.ST' 'Strict.STRef' as a reference, using
-- 'Strict.readSTRef' and 'Strict.writeSTRef'.
strictSTRef :: Strict.STRef s a -> Ref (Strict.ST s) a
strictSTRef r = MkRef (Strict.readSTRef r) (Strict.writeSTRef r)

-- | View a lazy-'Lazy.ST' 'Lazy.STRef' as a reference, using
-- 'Lazy.readSTRef' and 'Lazy.writeSTRef'.
lazySTRef :: Lazy.STRef s a -> Ref (Lazy.ST s) a
lazySTRef r = MkRef (Lazy.readSTRef r) (Lazy.writeSTRef r)

-- | Use a reference as a parameter.
--
-- Asking for the parameter reads the reference. Running an operation with a
-- given parameter value writes that value to the reference, runs the
-- operation, and is wrapped in 'refRestore' so that the reference's previous
-- value is put back afterwards.
refParam ::
       forall m a. MonadException m
    => Ref m a
    -> Param m a
refParam ref = let
    paramAsk = refGet ref
    paramWith :: a -> m --> m
    paramWith a mr =
        refRestore ref $ do
            refPut ref a
            mr
    -- The record fields are filled from the local bindings of the same name.
    in MkParam {..}

-- | Use a reference as a product.
--
-- Telling a value combines it into the reference's contents with '<>'.
-- Collecting runs an operation with the reference reset to 'mempty', and
-- returns the operation's result together with whatever the reference holds
-- when the operation finishes; the whole thing is wrapped in 'refRestore' so
-- that the reference's previous contents are put back afterwards.
refProd ::
       forall m a. (MonadException m, Monoid a)
    => Ref m a
    -> Prod m a
refProd ref = let
    -- NOTE(review): this stores @a <> old@, i.e. the newly told value goes in
    -- front of what is already there. For a non-commutative monoid, later
    -- tells therefore precede earlier ones; confirm that this is the order
    -- 'Prod' is meant to have.
    prodTell a = refModify ref $ (<>) a
    prodCollect :: forall r. m r -> m (r, a)
    prodCollect mr =
        refRestore ref $ do
            refPut ref mempty
            r <- mr
            a <- refGet ref
            return (r, a)
    -- The record fields are filled from the local bindings of the same name.
    in MkProd {..}