{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-do-bind #-}
module TypedProtocol.Driver where
import Control.Concurrent.Class.MonadSTM
import Control.Monad.Class.MonadThrow (MonadThrow, throwIO)
import Data.IFunctor (At (..), Sing, SingI (sing))
import qualified Data.IFunctor as I
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import GHC.Exception (Exception)
import TypedProtocol.Codec
import TypedProtocol.Core
import Unsafe.Coerce (unsafeCoerce)
-- | The transport side of a protocol run: how a peer pushes a message out
-- and how it pulls the next message in. 'runPeerWithDriver' consumes a
-- 'Driver'; 'driverSimple' builds one.
data Driver role' ps m = Driver
  { sendMsg
      :: forall (send :: role') (recv :: role') (from :: ps) (st' :: ps) (st'' :: ps)
       . (SingI recv, SingI from, SingToInt ps, SingToInt role')
      => Sing recv
      -> Msg role' ps from '(send, st') '(recv, st'')
      -> m ()
  -- ^ Transmit a message. The first argument is the singleton of the
  -- message's receiving role.
  , recvMsg
      :: forall (st' :: ps)
       . SingToInt ps
      => Sing st'
      -> m (AnyMsg role' ps)
  -- ^ Obtain an incoming message for the state whose singleton is given.
  -- The message comes back with its type indices hidden inside 'AnyMsg'.
  }
-- | Interpret a 'Peer' against a 'Driver' until it finishes, returning the
-- value carried by its final 'IReturn'.
--
-- * 'IReturn' ends the run with the wrapped result.
-- * 'LiftM' runs the embedded action, then continues with the peer it yields.
-- * 'Yield' hands the message to 'sendMsg', addressed with the singleton of
--   the message's receiver, then continues with the sender's continuation.
-- * 'Await' asks 'recvMsg' for a message for the current state and feeds it
--   to the continuation.
runPeerWithDriver
  :: forall role' ps (r :: role') (st :: ps) m a
   . (Monad m, SingToInt role')
  => Driver role' ps m
  -> Peer role' ps r m (At a (Done r)) st
  -> m a
runPeerWithDriver Driver{sendMsg, recvMsg} = step
  where
    step :: forall st'. Peer role' ps r m (At a (Done r)) st' -> m a
    step (IReturn (At result)) = pure result
    step (LiftM action) = action >>= step
    step (Yield (msg :: Msg role' ps (st' :: ps) '(r, sps) '(recv :: role', rps)) next) =
      sendMsg (sing @recv) msg >> step next
    step (Await (next :: (Recv role' ps r st' I.~> Peer role' ps r m ia))) = do
      AnyMsg received <- recvMsg (sing @st')
      -- The message arrives existentially wrapped in 'AnyMsg', so the type
      -- checker no longer knows its indices; 'unsafeCoerce' asserts that
      -- @Recv received@ has the type the continuation expects.
      -- NOTE(review): this is only sound if the Driver's 'recvMsg' returns a
      -- message whose source state really is @st'@ (driverSimple's cache is
      -- keyed by 'singToInt' of that state, which suggests so). Confirm the
      -- remaining indices are guaranteed by the protocol definition.
      step (next (unsafeCoerce (Recv received)))
-- | A single trace event: either an outgoing or an incoming message.
data TraceSendRecv role' ps
  = TraceSendMsg (AnyMsg role' ps)
  -- ^ event for a message that was sent
  | TraceRecvMsg (AnyMsg role' ps)
  -- ^ event for a message that was received
-- | Renders an event as @"Send "@ or @"Recv "@ followed by the 'show'
-- output of the wrapped message. Only 'show' is defined, so 'showsPrec'
-- uses the class default.
instance (Show (AnyMsg role' ps)) => Show (TraceSendRecv role' ps) where
  show event = case event of
    TraceSendMsg msg -> "Send " <> show msg
    TraceRecvMsg msg -> "Recv " <> show msg
-- | A tracer is a callback that consumes one 'TraceSendRecv' event in @m@.
-- 'driverSimple' invokes it once per message sent and once per message
-- received.
type Tracer role' ps m = TraceSendRecv role' ps -> m ()
-- | A tracer that ignores every event it is handed.
--
-- The argument is never inspected (not even forced); the result is always
-- @pure ()@.
nullTracer :: (Monad m) => a -> m ()
nullTracer = const (pure ())
-- | Send functions, one per receiving role, keyed by the 'Int' that
-- 'singToInt' produces for the role's singleton ('driverSimple' looks them
-- up that way). The @role'@ parameter does not appear on the right-hand
-- side; it only documents which role type the keys encode.
type SendMap role' m bytes = IntMap (bytes -> m ())
-- | Build a 'Driver' from per-role send functions and a shared message cache.
--
-- * 'sendMsg' looks up the send function registered for the receiving role
--   (keyed by 'singToInt' of the receiver singleton) in the 'SendMap'. It
--   calls that function on the 'encode'd message, then emits a
--   'TraceSendMsg' event. All of this runs in @n@ and is lifted into @m@
--   with the supplied natural transformation.
-- * 'recvMsg' blocks (via STM 'retry') until the cache in the 'TVar' holds an
--   entry for the requested state's index. It removes that entry in the same
--   transaction, emits a 'TraceRecvMsg' event, and returns the message.
--
-- Sending to a receiver with no entry in the 'SendMap' is a programming
-- error. It raises 'error' with a message naming the missing role index.
driverSimple
  :: forall role' ps bytes m n
   . ( Monad m
     , Monad n
     , MonadSTM n
     )
  => Tracer role' ps n
  -- ^ receives one event per message sent or received
  -> Encode role' ps bytes
  -- ^ message encoder used by 'sendMsg'
  -> SendMap role' n bytes
  -- ^ send function per receiving role
  -> TVar n (MsgCache role' ps)
  -- ^ cache of incoming messages keyed by state index (presumably filled
  -- by 'decodeLoop' — confirm at the call site)
  -> (forall a. n a -> m a)
  -- ^ lifts actions from @n@ into @m@
  -> Driver role' ps m
driverSimple tracer Encode{encode} sendMap tvar liftFun =
  Driver{sendMsg, recvMsg}
  where
    sendMsg
      :: forall (send :: role') (recv :: role') (from :: ps) (st :: ps) (st1 :: ps)
       . ( SingI recv
         , SingI from
         , SingToInt ps
         , SingToInt role'
         )
      => Sing recv
      -> Msg role' ps from '(send, st) '(recv, st1)
      -> m ()
    sendMsg role msg = liftFun $ do
      let roleKey = singToInt role
      case IntMap.lookup roleKey sendMap of
        -- Previously @error "np"@; the new message says which lookup failed.
        Nothing ->
          error $
            "driverSimple.sendMsg: no send function in SendMap for receiver role index "
              ++ show roleKey
        Just sendFun -> sendFun (encode msg)
      tracer (TraceSendMsg (AnyMsg msg))

    recvMsg
      :: forall (st' :: ps)
       . (SingToInt ps)
      => Sing st'
      -> m (AnyMsg role' ps)
    recvMsg sst' = do
      let singInt :: Int
          singInt = singToInt sst'
      liftFun $ do
        -- Take the cached message for this state, blocking until one exists;
        -- the read and the delete happen in one transaction.
        anyMsg <- atomically $ do
          agencyMsg <- readTVar tvar
          case IntMap.lookup singInt agencyMsg of
            Nothing -> retry
            Just v -> do
              writeTVar tvar (IntMap.delete singInt agencyMsg)
              pure v
        tracer (TraceRecvMsg anyMsg)
        pure anyMsg
-- | Repeatedly decode messages from a 'Channel' and park each one in the
-- shared 'MsgCache'. The key is the 'Int' index of the message's source
-- state ('singToInt' of 'msgFromStSing').
--
-- If a message with the same key is already waiting in the cache, the
-- transaction 'retry's. The loop therefore blocks until that entry has been
-- removed (e.g. by 'recvMsg' in 'driverSimple'). Leftover bytes returned by
-- the decoder are passed into the next decoding step. A decoding failure is
-- rethrown with 'throwIO', which ends the loop; otherwise it never returns.
--
-- NOTE(review): the 'Tracer' argument is unused here; this function emits
-- no trace events itself.
decodeLoop
  :: forall failure n role' ps bytes
   . (Exception failure, MonadSTM n, MonadThrow n)
  => Tracer role' ps n
  -> Maybe bytes
  -> Decode role' ps failure bytes
  -> Channel n bytes
  -> TVar n (MsgCache role' ps)
  -> n ()
decodeLoop _tracer initialLeftover Decode{decode} channel tvar = loop initialLeftover
  where
    loop :: Maybe bytes -> n ()
    loop leftover = do
      result <- runDecoderWithChannel channel leftover decode
      case result of
        Left failure -> throwIO failure
        Right (AnyMsg msg, leftover') -> do
          let key = singToInt (msgFromStSing msg)
          atomically $ do
            cache <- readTVar tvar
            if IntMap.member key cache
              then retry
              else writeTVar tvar (IntMap.insert key (AnyMsg msg) cache)
          loop leftover'