-- This code is adapted from typed-protocols, with a few modifications.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-do-bind #-}

{- |
Schematic diagram of how three roles communicate through typed-session:

<<data/fm.png>>

Some explanations for this diagram:

1. Roles are connected through channels. Many kinds of channel are possible, for example channels built on TCP or on a TMVar.

2. Each role runs its Peer in a dedicated Peer thread.

3. Each role has one or more decode threads, which place the decoded Msgs in the MsgCache.

4. SendMap aggregates the send functions of multiple Channels.
When sending a message, the receiver's send function is looked up in SendMap.
-}
module TypedProtocol.Driver where

import Control.Concurrent.Class.MonadSTM
import Control.Monad.Class.MonadThrow (MonadThrow, throwIO)
import Data.IFunctor (At (..), Sing, SingI (sing))
import qualified Data.IFunctor as I
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import GHC.Exception (Exception)
import TypedProtocol.Codec
import TypedProtocol.Core
import Unsafe.Coerce (unsafeCoerce)

{- |
The two effectful primitives that 'runPeerWithDriver' needs:
'sendMsg' to emit a Msg and 'recvMsg' to obtain the next Msg.
-}
data Driver role' ps m = Driver
  { sendMsg
      :: forall (sender :: role') (receiver :: role') (from :: ps) (senderNext :: ps) (receiverNext :: ps)
       . (SingI receiver, SingI from, SingToInt ps, SingToInt role')
      => Sing receiver
      -> Msg role' ps from '(sender, senderNext) '(receiver, receiverNext)
      -> m ()
  -- ^ Send a Msg; the 'Sing' argument identifies the receiving role.
  , recvMsg
      :: forall (current :: ps)
       . SingToInt ps
      => Sing current
      -> m (AnyMsg role' ps)
  -- ^ Obtain a Msg. 'runPeerWithDriver' calls this with the singleton of the
  -- state it is awaiting in and receives the Msg wrapped in 'AnyMsg'.
  }

{- |
Interpret a 'Peer' in the underlying monad @m@, using the 'Driver' for all
message exchange.

Each 'Peer' constructor is handled as follows:

* 'IReturn': stop and return the wrapped result.
* 'LiftM': run the wrapped @m@ action, then continue with the 'Peer' it yields.
* 'Yield': send the Msg with 'sendMsg', then continue with the rest.
* 'Await': obtain a Msg with 'recvMsg' and feed it to the continuation.
-}
runPeerWithDriver
  :: forall role' ps (r :: role') (st :: ps) m a
   . ( Monad m
     , (SingToInt role')
     )
  => Driver role' ps m
  -> Peer role' ps r m (At a (Done r)) st
  -> m a
runPeerWithDriver Driver{sendMsg, recvMsg} = loop
 where
  -- The protocol state changes at every step, so the helper is polymorphic
  -- in the state index.
  loop
    :: forall here
     . Peer role' ps r m (At a (Done r)) here
    -> m a
  loop (IReturn (At result)) = pure result
  loop (LiftM action) = action >>= loop
  loop (Yield (message :: Msg role' ps (here :: ps) '(r, afterSend) '(dest :: role', afterRecv)) rest) =
    sendMsg (sing @dest) message >> loop rest
  loop (Await (next :: (Recv role' ps r here I.~> Peer role' ps r m out))) =
    recvMsg (sing @here) >>= \(AnyMsg message) ->
      -- NOTE(review): this 'unsafeCoerce' is not checked at runtime. It
      -- asserts that the received Msg is exactly the 'Recv' the continuation
      -- expects in state @here@. With 'driverSimple', 'recvMsg' takes the Msg
      -- from the MsgCache slot keyed by 'singToInt' of @here@, and
      -- 'decodeLoop' fills slots keyed by 'singToInt' of each Msg's @from@
      -- state. Soundness therefore relies on 'singToInt' being injective and
      -- on the remote side following the protocol; confirm before relying
      -- on it.
      loop (next (unsafeCoerce (Recv message)))

{- |
A trace event: an 'AnyMsg' tagged with the direction it travelled.
-}
data TraceSendRecv role' ps
  = TraceSendMsg (AnyMsg role' ps)
  -- ^ A Msg that was sent.
  | TraceRecvMsg (AnyMsg role' ps)
  -- ^ A Msg that was received.

-- | Renders an event as @"Send "@ or @"Recv "@ followed by the 'show' of the
-- wrapped 'AnyMsg'.
instance (Show (AnyMsg role' ps)) => Show (TraceSendRecv role' ps) where
  show event = case event of
    TraceSendMsg anyMsg -> "Send " <> show anyMsg
    TraceRecvMsg anyMsg -> "Recv " <> show anyMsg

{- |
A logging callback for 'TraceSendRecv' events, used to print Msgs as they
are sent or received.
-}
type Tracer role' ps m = (TraceSendRecv role' ps -> m ())

{- |
The default trace function. It ignores its argument and does nothing.

Only 'Applicative' is required (the body just uses 'pure'). Every 'Monad'
is an 'Applicative', so existing uses as a 'Tracer' keep working.
-}
nullTracer :: (Applicative m) => a -> m ()
nullTracer _ = pure ()

{- |
The send functions of several Channels, gathered in one map keyed by
'singToInt' of the receiving role. When sending a Msg, 'driverSimple' looks up
the receiver's send function here. The @role'@ parameter does not appear on
the right-hand side; it only records which role type the keys come from.
-}
type SendMap role' m bytes = IntMap.IntMap (bytes -> m ())

{- |
Build a 'Driver' from a 'SendMap' and a MsgCache.

Arguments:

1. @Tracer role' ps n@: called with 'TraceSendMsg' after each send and with
   'TraceRecvMsg' after each receive, e.g. to log Msgs.
2. @Encode role' ps bytes@: encodes a Msg into @bytes@ before it is sent.
3. @SendMap role' n bytes@: the send functions, looked up by 'singToInt' of
   the receiving role.
4. @TVar n (MsgCache role' ps)@: the MsgCache that 'recvMsg' takes Msgs from
   (presumably filled by 'decodeLoop' in another thread).
5. @forall a. n a -> m a@: lifts actions of the bottom layer @n@ into @m@.
   The layers are described below.

A Peer can be seen as three layers:

1. 'Peer', the upper layer: satisfies the requirements of a McBride indexed
   monad, is built with do syntax, has semantic checks, and is interpreted
   into the middle layer @m@ by 'runPeerWithDriver'.
2. @m@, the middle layer: describes the business requirements and turns
   received Msgs into concrete business actions.
3. @n@, the bottom layer: sends and receives bytes. It can be, for example,
   IO or IOSim; IOSim makes the code easy to test.

Sending to a role with no entry in the 'SendMap' is a programming error and
calls 'error' with a message naming the missing role index.

Receiving blocks (STM 'retry') until the MsgCache holds a Msg under
'singToInt' of the requested state, then removes that Msg from the cache.
-}
driverSimple
  :: forall role' ps bytes m n
   . ( Monad m
     , Monad n
     , MonadSTM n
     )
  => Tracer role' ps n
  -> Encode role' ps bytes
  -> SendMap role' n bytes
  -> TVar n (MsgCache role' ps)
  -> (forall a. n a -> m a)
  -> Driver role' ps m
driverSimple tracer Encode{encode} sendMap tvar liftFun =
  Driver{sendMsg, recvMsg}
 where
  sendMsg
    :: forall (send :: role') (recv :: role') (from :: ps) (st :: ps) (st1 :: ps)
     . ( SingI recv
       , SingI from
       , SingToInt ps
       , SingToInt role'
       )
    => Sing recv
    -> Msg role' ps from '(send, st) '(recv, st1)
    -> m ()
  sendMsg role msg = liftFun $ do
    let receiverKey = singToInt role
    case IntMap.lookup receiverKey sendMap of
      -- No channel is registered for this receiver: a wiring bug in the
      -- caller, so fail with a message that identifies the missing role.
      Nothing ->
        error $
          "driverSimple.sendMsg: no send function in SendMap for receiver role "
            ++ show receiverKey
      Just sendFun -> sendFun (encode msg)
    tracer (TraceSendMsg (AnyMsg msg))

  recvMsg
    :: forall (st' :: ps)
     . (SingToInt ps)
    => Sing st'
    -> m (AnyMsg role' ps)
  recvMsg sst' = do
    let singInt = singToInt sst'
    liftFun $ do
      anyMsg <- atomically $ do
        agencyMsg <- readTVar tvar
        case IntMap.lookup singInt agencyMsg of
          -- Nothing cached for this state yet: block until the cache changes.
          Nothing -> retry
          Just v -> do
            -- Consume the Msg so the decoder can store the next one.
            writeTVar tvar (IntMap.delete singInt agencyMsg)
            pure v
      tracer (TraceRecvMsg anyMsg)
      pure anyMsg

{- |
Decode loop, usually run in a separate thread.

Each decoded Msg is stored in the MsgCache under 'singToInt' of the Msg's
@from@ state:

@
data Msg role' ps (from :: ps) (sendAndSt :: (role', ps)) (recvAndSt :: (role', ps))
@

If the MsgCache already holds a Msg for the same @from@ state, the loop
blocks (STM 'retry') until that Msg has been consumed, and only then stores
the new one. This usually happens when Msgs are produced faster than they
are consumed.

The @Maybe bytes@ returned with each decoded Msg (presumably input not yet
consumed) is passed to the next iteration. A decoding failure is thrown with
'throwIO' and ends the loop. The tracer is only threaded through the
recursive call; nothing is traced here.
-}
decodeLoop
  :: (Exception failure, MonadSTM n, MonadThrow n)
  => Tracer role' ps n
  -> Maybe bytes
  -> Decode role' ps failure bytes
  -> Channel n bytes
  -> TVar n (MsgCache role' ps)
  -> n ()
decodeLoop tracer mbt d@Decode{decode} channel tvar = do
  outcome <- runDecoderWithChannel channel mbt decode
  case outcome of
    Left failure -> throwIO failure
    Right (received@(AnyMsg msg), leftover) -> do
      let slot = singToInt (msgFromStSing msg)
      atomically $ do
        cache <- readTVar tvar
        -- An occupied slot means the previous Msg for this state has not
        -- been consumed yet: wait instead of overwriting it.
        if IntMap.member slot cache
          then retry
          else writeTVar tvar (IntMap.insert slot received cache)
      decodeLoop tracer leftover d channel tvar