-- | Working with 'Foreign.Ptr's in a way that prevents use-after-free bugs
--
-- >>> :set -XPostfixOperators
-- >>> import Control.Monad.Scoped.Internal
-- >>> scoped do x <- newPtr (69 :: Word); x `setPtr` 42; getPtr x
-- 42
--
-- @since 0.1.0.0
module Control.Monad.Scoped.Ptr
  ( Ptr
  , foreignPtr
  , wrapScoped
  , newPtr
  , setPtr
  , getPtr
  )
where

import Control.Monad.IO.Unlift (MonadIO (liftIO), MonadUnliftIO (withRunInIO))
import Control.Monad.Scoped.Internal (Scoped (UnsafeMkScoped), ScopedResource (UnsafeMkScopedResource, unsafeUnwrapScopedResource), bracketScoped, (:<))
import Control.Monad.Trans.Class (lift)
import Foreign qualified

-- | A 'Foreign.Ptr' that is associated to a scope
--
-- @since 0.1.0.0
type Ptr s a = ScopedResource s (Foreign.Ptr a)

-- | Acquire mutable memory for the duration of a scope.
--
-- The memory is allocated with 'Foreign.new', which also stores the given
-- initial value in it. The allocation is registered with the current scope
-- @s@, and the memory is released with 'Foreign.free' when that scope ends.
--
-- @since 0.2.0.0
newPtr :: (Foreign.Storable a, MonadUnliftIO m) => a -> Scoped (s : ss) m (Ptr s a)
newPtr a =
  bracketScoped
    -- acquire: allocate room for one value and write the initial value
    (liftIO (Foreign.new a))
    -- release: free the allocation when the scope is left
    (liftIO . Foreign.free)

-- | Write a value to a scoped pointer, overwriting its current contents.
--
-- The write goes through 'Foreign.poke' on the underlying 'Foreign.Ptr'.
-- The @s ':<' ss@ constraint relates the pointer's scope @s@ to the scopes
-- @ss@ of the calling computation.
--
-- @since 0.2.0.0
setPtr :: (Foreign.Storable a, MonadIO m, s :< ss) => Ptr s a -> a -> Scoped ss m ()
setPtr ptr x = liftIO (Foreign.poke (unsafeUnwrapScopedResource ptr) x)

-- | Read the value currently stored behind a scoped pointer.
--
-- The read goes through 'Foreign.peek' on the underlying 'Foreign.Ptr'.
-- The @s ':<' ss@ constraint relates the pointer's scope @s@ to the scopes
-- @ss@ of the calling computation.
--
-- @since 0.2.0.0
getPtr :: (Foreign.Storable a, MonadIO m, s :< ss) => Ptr s a -> Scoped ss m a
getPtr ptr = liftIO (Foreign.peek (unsafeUnwrapScopedResource ptr))

-- | A wrapper around 'Foreign.withForeignPtr' that allows the function to be
-- used safely inside a scope.
--
-- The rest of the scope runs inside 'Foreign.withForeignPtr'. It receives
-- the underlying 'Foreign.Ptr' wrapped as a scoped resource, tied to the
-- current scope @s@. 'withRunInIO' is used to run the monadic continuation
-- from inside the 'IO' callback.
--
-- @since 0.1.0.2
foreignPtr :: MonadUnliftIO m => Foreign.ForeignPtr a -> Scoped (s : ss) m (Ptr s a)
foreignPtr fptr =
  UnsafeMkScoped (\k ->
    withRunInIO (\runInIO ->
      -- wrap the raw pointer, hand it to the scope's continuation,
      -- and lower the result back into IO for withForeignPtr
      Foreign.withForeignPtr fptr (runInIO . k . UnsafeMkScopedResource)))

-- | Take a function that works on a raw 'Foreign.Ptr' and run it on a
-- scoped pointer.
--
-- The callback receives the unwrapped 'Foreign.Ptr' and its action is
-- lifted into 'Scoped'. The raw pointer is handed over without any further
-- wrapping. The callback should therefore not retain it beyond the call:
-- once the owning scope has ended, the pointer may refer to freed memory.
--
-- @since 0.1.0.2
wrapScoped :: (Monad m, s :< ss) => (Foreign.Ptr a -> m r) -> Ptr s a -> Scoped ss m r
wrapScoped k p = lift (k (unsafeUnwrapScopedResource p))