{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}

-- | Module: Capnp.Rpc.Membrane
-- Description: Helpers for working with membranes.
--
-- Membranes are common in object-capability design. Think of it like a
-- proxy on steroids: a membrane inserts itself in front of another capability,
-- and can intercept and modify method calls. Unlike a simple proxy though,
-- the membrane will also be applied to any objects returned by method calls,
-- or passed in arguments, transitively, so it can sit in front of entire
-- object graphs.
module Capnp.Rpc.Membrane
  ( enclose,
    exclude,
    Policy,
    Action (..),
    Direction (..),
    Call (..),
  )
where

import qualified Capnp.Message as M
import Capnp.Mutability (Mutability (..))
import Capnp.Rpc.Promise (breakOrFulfill, newCallback)
import qualified Capnp.Rpc.Server as Server
import qualified Capnp.Rpc.Untyped as URpc
import qualified Capnp.Untyped as U
import Control.Concurrent.STM
import Control.Monad (void)
import Control.Monad.STM.Class
import Data.Typeable (Typeable, cast)
import Data.Word
import Supervisors (Supervisor)

-- | The decision a 'Policy' makes about a method call that is crossing the
-- membrane.
data Action
  = -- | Answer the call with the given handler instead of letting it cross
    -- the membrane. The membrane does not wrap or unwrap the arguments or
    -- the results in this case, so take care when delegating to objects
    -- that live inside the membrane.
    Handle Server.UntypedMethodHandler
  | -- | Let the call continue to its original target. Arguments and results
    -- are wrapped/unwrapped by the membrane as usual.
    Forward

-- | Which way a method call travels relative to the membrane.
data Direction
  = -- | Something outside the membrane is calling something inside it.
    In
  | -- | Something inside the membrane is calling something outside it.
    Out
  deriving (Eq, Read, Show)

-- | The opposite 'Direction' (equivalently, the other 'Side').
flipDir :: Direction -> Direction
flipDir d = case d of
  In -> Out
  Out -> In

-- | Another name for 'Direction'. It is used where it is more natural to
-- talk about which side of the membrane a capability is /on/, rather than
-- which way a method call is /going/.
type Side = Direction

-- | A method call that is crossing the membrane, as presented to a 'Policy'.
data Call = Call
  { -- | Which way the call is going. 'In' means something outside the
    -- membrane is calling something inside it. 'Out' means something inside
    -- the membrane is calling something outside it.
    direction :: Direction,
    -- | Interface id of the method being invoked.
    interfaceId :: Word64,
    -- | Ordinal of the method within its interface.
    methodId :: Word16,
    -- | The client the call is addressed to.
    target :: URpc.Client
  }

-- | A 'Policy' inspects each 'Call' crossing the membrane and picks the
-- 'Action' to take. The decision runs as an 'STM' transaction.
type Policy = Call -> STM Action

-- | @'enclose' sup cap policy@ places @cap@ inside a fresh membrane governed
-- by @policy@. It returns the client through which the enclosed capability
-- is reached; calls made on that client are presented to @policy@ as 'In'.
enclose :: (URpc.IsClient c, MonadSTM m) => Supervisor -> c -> Policy -> m c
enclose sup cap policy = newMembrane In sup cap policy

-- | 'exclude' is the mirror image of 'enclose'. Here @cap@ is treated as
-- sitting /outside/ a membrane that surrounds everything else, so calls made
-- on the returned client are presented to @policy@ as 'Out'.
exclude :: (URpc.IsClient c, MonadSTM m) => Supervisor -> c -> Policy -> m c
exclude sup cap policy = newMembrane Out sup cap policy

-- | Create a fresh membrane governed by @policy@ and pass @toWrap@ through
-- it. @toWrap@ is treated as living on side @dir@ of the new membrane.
newMembrane :: (URpc.IsClient c, MonadSTM m) => Direction -> Supervisor -> c -> Policy -> m c
newMembrane dir sup toWrap policy = liftSTM $ do
  -- A fresh TVar gives this membrane a unique identity. Wrappers use it to
  -- recognise whether they belong to the same membrane.
  identity <- newTVar ()
  wrapped <- pass dir sup Membrane {policy, identity} (URpc.toClient toWrap)
  pure (URpc.fromClient wrapped)

-- | The state behind a client exported by the membrane.
data MembraneWrapped = MembraneWrapped
  { -- | The underlying client this wrapper stands in front of.
    client :: URpc.Client,
    -- | The membrane this wrapper belongs to.
    membrane :: Membrane,
    -- | The side of 'membrane' recorded for 'client'.
    side :: Direction
  }
  deriving (Typeable)

-- | The state shared by all wrappers belonging to one membrane.
data Membrane = Membrane
  { -- | Decides what happens to each call crossing the membrane.
    policy :: Policy,
    -- | A value with identity (TVars compare by reference). It exists only
    -- so that membranes can be told apart.
    identity :: TVar ()
  }

-- | Two membranes are equal exactly when they share the same identity
-- 'TVar'. The policy is not compared.
instance Eq Membrane where
  Membrane {identity = a} == Membrane {identity = b} = a == b

-- | @'wrapHandler' receiverSide sup mem handler@ adapts @handler@, which
-- serves calls for an object on side @receiverSide@ of membrane @mem@. Every
-- capability travelling with a call is also passed through the membrane.
--
-- * Capabilities in the arguments were supplied by the caller, so they live
--   on the caller's side (@'flipDir' receiverSide@).
-- * Capabilities in the results were produced by the receiver, so they live
--   on @receiverSide@.
--
-- Each set is passed using the side it actually lives on. As a result,
-- wrappers created for them report the correct 'Direction' to the policy,
-- and wrappers returning across the membrane are unwrapped rather than
-- wrapped a second time.
wrapHandler ::
  Side ->
  Supervisor ->
  Membrane ->
  Server.UntypedMethodHandler ->
  Server.UntypedMethodHandler
wrapHandler receiverSide sup mem handler = Server.untypedHandler $ \arguments response -> do
  (args, resp) <- atomically $ do
    args' <- passPtr callerSide sup mem arguments
    -- Results flow back from the receiver to the caller. Pass them before
    -- completing the caller's response; an exception result (Left) is
    -- forwarded unchanged by 'traverse'.
    resp' <- newCallback $ \result ->
      traverse (passPtr receiverSide sup mem) result
        >>= breakOrFulfill response
    pure (args', resp')
  Server.handleUntypedMethod handler args resp
  where
    callerSide = flipDir receiverSide

-- | Pass every capability referenced by the message behind an optional
-- pointer through the membrane. The capabilities are treated as living on
-- side @dir@.
passPtr ::
  MonadSTM m =>
  Direction ->
  Supervisor ->
  Membrane ->
  Maybe (U.Ptr 'Const) ->
  m (Maybe (U.Ptr 'Const))
passPtr dir sup mem = liftSTM . traverse (U.tMsg $ passMessage dir sup mem)

-- | Replace every entry in a message's capability table with the result of
-- passing it through the membrane from side @dir@.
passMessage ::
  MonadSTM m =>
  Direction ->
  Supervisor ->
  Membrane ->
  M.Message 'Const ->
  m (M.Message 'Const)
passMessage dir sup mem msg = liftSTM $ do
  caps' <- traverse (pass dir sup mem) (M.getCapTable msg)
  pure $ M.withCapTable caps' msg

-- | @'pass' dir sup mem inClient@ returns a client to hand to the opposite
-- side of the membrane, where @inClient@ lives on side @dir@.
--
-- Suppose @inClient@ is already a wrapper of this membrane whose wrapped
-- client lives on @'flipDir' dir@, i.e. it is crossing back to where it came
-- from. In that case it is simply unwrapped, so round trips do not pile up
-- layers of wrapping.
--
-- Otherwise a new wrapper is exported. Every call on that wrapper is shown
-- to the membrane's policy with direction @dir@.
pass :: MonadSTM m => Direction -> Supervisor -> Membrane -> URpc.Client -> m URpc.Client
pass dir sup mem inClient = liftSTM $
  case URpc.unwrapServer inClient :: Maybe MembraneWrapped of
    Just mw | onSide (flipDir dir) mw mem -> pure $ client mw
    _ ->
      URpc.export
        sup
        Server.ServerOps
          { Server.handleCast =
              cast $
                MembraneWrapped
                  { client = inClient,
                    membrane = mem,
                    side = dir
                  },
            -- Once we're gc'd, the downstream client will be as well,
            -- and then the relevant shutdown logic will still be called.
            -- This introduces latency, unfortunately, but nothing is broken.
            Server.handleStop = pure (),
            Server.handleCall = \interfaceId methodId ->
              Server.untypedHandler $ \rawArgs rawResponse -> do
                action <-
                  atomically $
                    policy mem Call {interfaceId, methodId, direction = dir, target = inClient}
                case action of
                  -- The policy's handler sees the call exactly as received.
                  Handle h -> Server.handleUntypedMethod h rawArgs rawResponse
                  Forward ->
                    Server.handleUntypedMethod
                      ( wrapHandler dir sup mem $
                          -- 'arguments'/'response' here are the versions
                          -- already passed through the membrane by
                          -- 'wrapHandler'; CallInfo {..} picks them up
                          -- together with interfaceId and methodId.
                          Server.untypedHandler $ \arguments response ->
                            void $ URpc.call Server.CallInfo {..} inClient
                      )
                      rawArgs
                      rawResponse
          }

-- | @'onSide' dir mw mem@ holds when wrapper @mw@ belongs to membrane @mem@
-- and its recorded 'side' is @dir@.
onSide :: Direction -> MembraneWrapped -> Membrane -> Bool
onSide dir MembraneWrapped {membrane = owner, side = recorded} mem =
  owner == mem && recorded == dir