{-# LANGUAGE BangPatterns             #-}
{-# LANGUAGE ConstraintKinds          #-}
{-# LANGUAGE DataKinds                #-}
{-# LANGUAGE DeriveGeneric            #-}
{-# LANGUAGE FlexibleInstances        #-}
{-# LANGUAGE MagicHash                #-}
{-# LANGUAGE MultiParamTypeClasses    #-}
{-# LANGUAGE PolyKinds                #-}
{-# LANGUAGE RankNTypes               #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies             #-}
{-# LANGUAGE UnboxedTuples            #-}

module Control.Lens.Mutable.Types where

import           Control.Lens.Lens       (ALens', cloneLens)
import           Control.Lens.Type       (Lens', LensLike')
import           Control.Monad.Primitive (PrimBase (..), PrimMonad (..))
import           Data.Kind               (Type)
import           GHC.Conc                (STM (..))
import           GHC.Exts                (RealWorld, State#)
import           GHC.Generics            (Generic)

-- | GHC implements different primitive operations, some of which cannot be
-- mixed together and some of which can only be run in certain contexts. In
-- particular, 'STM'-related primops cannot be run directly in the 'IO' monad.
-- However, this restriction is not represented at the bottom layer of the 'IO'
-- runtime which we need to wrap around and expose to users.
--
-- This data structure is our ad-hoc attempt to group together "compatible"
-- primops so that only lenses representing compatible references can be
-- composed together, avoiding deadly segfaults.
--
-- See https://gitlab.haskell.org/ghc/ghc/blob/master/compiler/prelude/primops.txt.pp
--
-- See also https://github.com/haskell/primitive/issues/43#issuecomment-613771394
data PrimOpGroup
  = OpST
    -- ^ Group whose transitions can be run in any 'PrimMonad'.
  | OpMVar
    -- ^ Group whose transitions are run in 'IO'; kept apart from 'OpST' on
    -- purpose.
  | OpSTM
    -- ^ Group whose transitions can only be run in 'STM'.
  deriving (Read, Show, Generic, Eq, Ord, Enum, Bounded)

-- | Lifted wrapper around a 'State#' token. It is needed to let lifted
-- ("normal") types interoperate with unlifted types (such as primitives), and
-- it also gives us the chance to restrict composition based on 'PrimOpGroup';
-- sadly that is not done in the unlifted internal representation, though it
-- could be.
--
-- The @p@ parameter does not appear on the right-hand side: it is a phantom
-- index used purely for type-level bookkeeping.
type S :: PrimOpGroup -> Type -> Type
data S p s = S !(State# s)

-- | A lifted primitive state transformer that interoperates with lens.
--
-- Concretely, it is a bare state transition (not wrapped in @StateT@) over a
-- lifted ("normal") state type: it consumes an @'S' p s@ and produces a result
-- of type @r@ paired with the next @'S' p s@.
--
-- You can obtain one by applying an @'SLens' p s a@ to a bare state
-- transition, i.e. a function of type @(a -> (r, a))@.
type LST :: PrimOpGroup -> Type -> Type -> Type
type LST p s r = S p s -> (r, S p s)

-- | Convert an @'LST' p@ to some context @m@.
--
-- This is similar to 'PrimMonad' from the @primitive@ package, except that our
-- extra @p@ type parameter helps us avoid accidentally mixing incompatible
-- primops.
class FromLST p s m where
  -- | Run a lifted state transition in the context @m@.
  stToM :: LST p s r -> m r

-- | Convert an @'LST' p@ to and from some context @m@.
--
-- This is similar to 'PrimBase' from the @primitive@ package, except that our
-- extra @p@ type parameter helps us avoid accidentally mixing incompatible
-- primops.
class FromLST p s m => IsoLST p s m where
  -- | Turn an action in @m@ back into a lifted state transition (the opposite
  -- direction to 'stToM').
  mToST :: m r -> LST p s r

-- | 'OpST' transitions run in any 'PrimMonad' whose state token is @s@.
--
-- The raw 'State#' token handed over by 'primitive' is wrapped in 'S', threaded
-- through the transition, and unwrapped again for the result.
instance (PrimMonad m, s ~ PrimState m) => FromLST 'OpST s m where
  stToM st = primitive $ \s1# ->
    -- GHC requires a pattern binding that binds an unlifted variable (s2#)
    -- to be strict, hence the outermost bang.
    let !(a, S s2#) = st (S s1#)
    in  (# s2#, a #)
  {-# INLINE stToM #-}

-- | Any 'PrimBase' action whose state token is @s@ can be turned back into an
-- 'OpST' transition.
--
-- The action is run via 'internal' on the token carried by the input 'S', and
-- the resulting token is re-wrapped in 'S' next to the result.
instance (PrimBase m, s ~ PrimState m) => IsoLST 'OpST s m where
  mToST prim (S s1#) =
    let !(# s2#, a #) = internal prim s1#
    in  (a, S s2#)
  {-# INLINE mToST #-}

-- | 'OpMVar' transitions run in 'IO'.
--
-- The implementation is the same as for 'OpST'; the two groups are only
-- forcibly kept apart at the type level, to be safe.
instance FromLST 'OpMVar RealWorld IO where
  stToM st = primitive $ \s1# ->
    -- Strict (banged) binding required because s2# is unlifted.
    let !(a, S s2#) = st (S s1#)
    in  (# s2#, a #)
  {-# INLINE stToM #-}

-- | An 'IO' action can be turned back into an 'OpMVar' transition by running
-- it via 'internal' on the token carried by the input 'S'.
instance IsoLST 'OpMVar RealWorld IO where
  mToST prim (S s1#) =
    let !(# s2#, a #) = internal prim s1#
    in  (a, S s2#)
  {-# INLINE mToST #-}

-- | 'OpSTM' transitions run in 'STM', by building the 'STM' value directly from
-- its underlying state-token function.
instance FromLST 'OpSTM RealWorld STM where
  stToM st = STM $ \s1# ->
    -- Strict (banged) binding required because s2# is unlifted.
    let !(a, S s2#) = st (S s1#)
    in  (# s2#, a #)
  {-# INLINE stToM #-}

-- | An 'STM' action is turned back into an 'OpSTM' transition by unwrapping
-- its state-token function and applying it to the token carried by the input
-- 'S'.
instance IsoLST 'OpSTM RealWorld STM where
  mToST (STM state#) (S s1#) =
    let !(# s2#, a #) = state# s1#
    in  (a, S s2#)
  {-# INLINE mToST #-}

-- | Constraint synonym for a monad @m@ into which @'LST' p s@ transitions can
-- be converted (via 'stToM').
type MonadLST p s m = (FromLST p s m, Monad m)

-- | Representation of a mutable reference as a 'Lens''.
--
-- When the lens functor type-param is @(,) r@, then the output transition
-- function is of type @'LST' p s r@. To use it as a monadic action e.g. to run
-- it, you'll need to first convert it using 'stToM'.
--
-- Again, in principle this ought not to be necessary, but the Haskell runtime
-- forces us to do this due to historical design decisions to hide necessary
-- details that seemed appropriate to hide at the time.
type SLens p s a = Lens' (S p s) a

-- | Representation of a mutable reference as an 'ALens''.
--
-- This type is useful if you need to store a lens in a container. To recover
-- the original type, pass it through 'Control.Lens.cloneLens'.
type ASLens p s a = ALens' (S p s) a

-- ** Convenience functions

-- These are all compositions of the basic functions above, provided for
-- convenience rather than necessity.

-- | Run a bare state transition on a lens in the monad for @p@.
--
-- The optic only has to be usable at the @(,) r@ functor. An @'SLens' p@, or
-- any composition of it with further lenses, always qualifies. Prisms and
-- traversals can be composed in as well, but only when @r@ is a 'Monoid',
-- since @(,) r@ is 'Applicative' only in that case.
--
-- Applying the optic to the transition yields an @'LST' p s r@, which is then
-- run with 'stToM'.
runSLens :: FromLST p s m => LensLike' ((,) r) (S p s) a -> (a -> (r, a)) -> m r
runSLens l f = stToM (l f)
{-# INLINE runSLens #-}

-- | Run a bare state transition on an 'ALens'' in the monad for @p@.
--
-- The stored lens is first turned back into an ordinary lens with
-- 'cloneLens', then run with 'runSLens'.
runASLens :: FromLST p s m => ALens' (S p s) a -> (a -> (r, a)) -> m r
-- Eta-expanded on purpose. 'cloneLens' returns a rank-2 'Lens', and applying
-- it directly lets GHC instantiate that type at this call site. The point-free
-- form @runSLens . cloneLens@ would instead rely on deep subsumption, which is
-- not available by default since GHC 9.0's simplified subsumption.
runASLens l = runSLens (cloneLens l)
{-# INLINE runASLens #-}

-- | A bare state transition representing a read operation.
--
-- It returns the current value as the result and leaves the state unchanged,
-- so @'stateRead'@ can be passed to 'runSLens' to read the reference.
stateRead :: a -> (a, a)
stateRead a = (a, a)
{-# INLINE stateRead #-}

-- | A bare state transition representing a write operation.
--
-- The old value is ignored (and never evaluated); the new state is @b@ and the
-- result is @()@. @'stateWrite' b@ can be passed to 'runSLens' to write @b@ to
-- the reference.
stateWrite :: b -> a -> ((), b)
stateWrite b _ = ((), b)
{-# INLINE stateWrite #-}

-- | A bare state transition representing a modify/map operation.
--
-- The new state is @f@ applied to the old value, and the result is @()@.
-- Because of '$!', the old value is forced to WHNF before @f@ is applied.
-- However, this only happens once the new value itself is demanded: the new
-- value is left unevaluated inside the returned pair.
--
-- @'stateModify' f@ can be passed to 'runSLens' to apply @f@ to the reference.
stateModify :: (a -> b) -> a -> ((), b)
stateModify f a = ((), f $! a)
{-# INLINE stateModify #-}