{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeApplications #-}

-- | This module provides basic support for asynchronous communication
-- and computation of secret-shared values.

module Asyncoro (createConnections, send, receive, Gather(..), async, asyncList, asyncListList, await, incPC, decreaseBarrier) where
import Network.Socket
import Network.Socket.ByteString (recv, sendAll)
import Control.Exception
import Control.Concurrent
import System.IO.Error
import Control.Monad
import Types
import Data.Function
import Data.List
import qualified Data.Map.Strict as Map
import qualified Data.Serialize as Enc
import qualified Data.ByteString as BS
import Data.Hashable
import Control.Monad.State
import SecTypes
import FinFields
import System.Log.Logger
import Text.Printf
import Parser


-- | Open connections with other parties, if any.
--
-- The party at index @myPid@ accepts one inbound connection for every party
-- listed before it and dials every party listed after it. Each established
-- connection is serviced by 'runConnection' on its own thread.
--
-- Returns all parties sorted by 'pid'. Remote parties carry their socket,
-- outgoing channel, receive dictionary and byte counter; the local party's
-- entry gets fresh, unused ones and no socket.
--
-- Throws an 'IOError' if @myPid@ does not index into @parties@.
createConnections :: Int -> [Party] -> IO [Party]
createConnections myPid parties = do
    let m = length parties
    -- Total replacement for @parties !! myPid@ with a descriptive failure.
    self <- case drop myPid parties of
      (p : _) | myPid >= 0 -> return p
      _ -> ioError . userError $
             printf "createConnections: pid %d is out of range for %d parties" myPid m

    -- 'bracket' guarantees the listening socket is closed even when setup
    -- fails part-way; on success it is closed once every peer is connected.
    channels <- bracket (socket AF_INET Stream 0) close $ \listenSock -> do
        setSocketOption listenSock ReuseAddr 1   -- make socket immediately reusable - eases debugging.
        bind listenSock (SockAddrInet (fromIntegral (port self)) 0)   -- this party's configured port, any interface
        listen listenSock 1   -- backlog of 1; dialing peers retry until they get through

        -- One accept thread per party listed before us.
        serverParties <- replicateM myPid $ do
            slot <- newEmptyMVar
            void . forkIO $ connectServer listenSock slot
            return slot

        -- One dialing thread per party listed after us.
        clientParties <- forM (drop (myPid + 1) parties) $ \party -> do
            slot <- newEmptyMVar
            void . forkIO $ connectClient myPid party slot
            return slot

        -- Block until every handshake has completed.
        mapM takeMVar (serverParties ++ clientParties)

    logging INFO $ printf "All %d parties connected." m

    -- necessary if --single-threaded bug
    -- NOTE(review): the reason for the 5 s pause on a single capability is
    -- not visible in this file; kept as in the original.
    cap <- getNumCapabilities
    logging INFO $ printf "All threads run on %d cores." cap
    when (cap <= 1) (threadDelay 5000000)

    selfDict <- newMVar Map.empty
    selfBytes <- newMVar 0
    selfChan <- newChan
    let selfParty = self{outChan = selfChan, sock = Nothing, dict = selfDict, nbytesSent = selfBytes}
    return $ sortOn pid (selfParty : channels)

  where
    -- Accept one inbound connection, read the peer's announced pid and
    -- finish the handshake. A connection whose pid cannot be decoded or does
    -- not name a known party is dropped and another one is awaited, instead
    -- of crashing this thread and leaving the caller blocked forever.
    connectServer listenSock slot = do
      (conn, _) <- accept listenSock     -- accept a connection and handle it
      msg <- recv conn 1024              -- receive pid
      case Enc.decode msg of
        Right peerPid
          | peerPid >= 0, (peer : _) <- drop peerPid parties ->
              initConnection conn peer slot
          | otherwise ->
              rejectAndRetry listenSock slot conn
                (printf "ignoring connection announcing unknown pid %d" peerPid)
        Left err ->
          rejectAndRetry listenSock slot conn
            ("ignoring connection with undecodable pid: " ++ err)

    -- Log why an inbound connection was refused, close it, and accept again.
    rejectAndRetry listenSock slot conn reason = do
      logging WARNING reason
      close conn
      connectServer listenSock slot

    -- Dial a peer, retrying every 100 ms until the connection succeeds, then
    -- announce our own pid and finish the handshake.
    connectClient selfPid peer slot = do
      addrs <- getAddrInfo Nothing (Just (host peer)) (Just (show (port peer)))
      case addrs of
        [] -> ioError . userError $
                "createConnections: no address found for host " ++ host peer
        (addr : _) -> do
          peerSock <- openSocket addr
          res <- try @IOException $ connect peerSock (addrAddress addr)
          case res of
            Right () -> do  -- connection successful
              sendAll peerSock (Enc.encode selfPid) -- send pid
              initConnection peerSock peer slot
            Left _ -> do  -- peer not listening yet
              close peerSock   -- release the failed socket before retrying
              threadDelay 100000
              connectClient selfPid peer slot

    -- Attach fresh per-connection state to the peer, publish it to the
    -- waiting caller, then service the connection on this thread.
    initConnection peerSock peer slot = do
      inbox <- newMVar Map.empty
      byteCounter <- newMVar 1 -- one for the runSession to complete
      outbox <- newChan
      putMVar slot peer{outChan = outbox, sock = Just peerSock, dict = inbox, nbytesSent = byteCounter}
      runConnection outbox peerSock inbox byteCounter


-- | Service one peer connection until it ends.
--
-- A forked writer thread forwards everything written to the outgoing channel
-- onto the socket. The calling thread reads from the socket, reassembles
-- frames and files each payload in the dictionary under the frame's program
-- counter, either filling a slot a receiver already registered or storing a
-- pre-filled slot for a later receiver.
--
-- Frame layout: 4-byte little-endian payload length, 8-byte little-endian
-- program counter, then the payload.
--
-- The read loop ends when the peer closes the connection (empty read), on a
-- malformed frame, or on any exception; the writer thread is then killed.
-- The counter in the last argument is increased by the size of every
-- complete frame received.
runConnection :: Chan BS.ByteString -> Socket -> MVar Dict -> MVar Int -> IO ()
runConnection chan conn dictMvar nbytesSent = do
    writer <- forkIO . forever $ readChan chan >>= sendAll conn
    -- Best effort, as in the original: any failure just ends this connection.
    handle (\(SomeException _) -> return ()) (readLoop BS.empty)
    killThread writer
  where
    -- Bytes of length prefix plus program counter preceding each payload.
    headerSize :: Int
    headerSize = 12

    -- Read a chunk, decode every complete frame, carry the remainder over.
    -- An empty chunk means the peer closed the connection; stopping here
    -- avoids spinning forever on a dead socket.
    readLoop pending = do
      chunk <- recv conn 1024
      unless (BS.null chunk) $
        drainFrames (BS.append pending chunk) >>= readLoop

    -- Deliver all complete frames at the front of the buffer and return the
    -- undecoded tail.
    drainFrames buffer
      | BS.length buffer < headerSize = return buffer
      | otherwise =
          case Enc.runGet ((,) <$> Enc.getInt32le <*> Enc.getInt64le) buffer of
            Left err -> malformed err
            Right (rawLen, rawPc)
              | rawLen < 0 -> malformed "negative payload length"
              | BS.length buffer < frameLen -> return buffer
              | otherwise -> do
                  let (frame, rest) = BS.splitAt frameLen buffer
                  -- Forced so the counter does not accumulate thunks.
                  modifyMVar_ nbytesSent (\total -> return $! total + frameLen)
                  deliver (fromIntegral rawPc) (BS.drop headerSize frame)
                  drainFrames rest
              where
                frameLen = headerSize + fromIntegral rawLen

    -- Hand a payload to whoever waits on this program counter, or park it.
    deliver pc payload = modifyMVar_ dictMvar $ \pending ->
      case Map.lookup pc pending of
        Just slot -> do
          putMVar slot payload
          return (Map.delete pc pending)
        Nothing -> do
          slot <- newMVar payload
          return (Map.insert pc slot pending)

    -- Raised inside the read loop, so it terminates the connection.
    malformed reason =
      ioError (userError ("runConnection: malformed frame: " ++ reason))

-- | Receive payload labeled with given pc from the peer.
--
-- Returns the slot for that pc without blocking. If a slot is already stored
-- in the peer's dictionary it is removed and returned; otherwise a fresh
-- empty slot is registered under the pc and returned.
receive :: Int -> Party -> SIO (MVar BS.ByteString)
receive pc party = liftIO (modifyMVar (dict party) claim)
  where
    claim pending =
      maybe (register pending) (collect pending) (Map.lookup pc pending)

    -- A slot is already stored under this pc: take it out of the table.
    collect pending slot = return (Map.delete pc pending, slot)

    -- Nothing stored yet: leave an empty slot under this pc.
    register pending = do
      slot <- newEmptyMVar
      return (Map.insert pc slot pending, slot)

-- | Transform 'SecureTypes' into 'FiniteField' by reading the future 'MVar'
-- share that contains a 'FiniteField' (blocking).
--
-- Instances in this module extend this to lists, pairs and triples of
-- gatherable values.
class Gather a where
  -- | The type produced by gathering a value of type @a@.
  type Result a :: *
  -- | Gather the value inside 'SIO'.
  gather :: a -> SIO (Result a)

-- | A single secure value is gathered by awaiting its 'share'.
instance Gather SecureTypes where
  type Result SecureTypes = FiniteField
  gather secret = await (share secret)

-- | Lists are gathered element by element, front to back.
instance Gather a => Gather [a] where
  type Result [a] = [Result a]
  gather items = forM items gather

-- | Pairs are gathered component-wise, first component first.
instance (Gather a, Gather b) => Gather (a, b) where
  type Result (a, b) = (Result a, Result b)
  gather (x, y) = (,) <$> gather x <*> gather y

-- | Triples are gathered component-wise, left to right.
instance (Gather a, Gather b, Gather c) => Gather (a, b, c) where
  type Result (a, b, c) = (Result a, Result b, Result c)
  gather (x, y, z) = (,,) <$> gather x <*> gather y <*> gather z

-- | Read the value from the future MVar (blocking).
--
-- Uses 'readMVar', so the value stays in the 'MVar' afterwards.
await :: MVar a -> SIO a
await future = liftIO (readMVar future)


-- | Send payload labeled with pc to the peer.
--
-- The frame is queued on the peer's outgoing channel; this call does not
-- touch the socket itself.
--
-- Message format consists of three parts, in this order (the same order
-- 'runConnection' parses on the receiving side):
--
-- 1. payload_size (4 bytes, little-endian signed int)
--
-- 2. pc (8 bytes, little-endian signed int)
--
-- 3. payload (byte string of length payload_size).
send :: Int -> BS.ByteString -> Party -> SIO ()
send pc payload party = liftIO (writeChan (outChan party) frame)
  where
    -- Fixed-size header: length prefix followed by the program counter.
    header = Enc.runPut $ do
      Enc.putInt32le (fromIntegral (BS.length payload))
      Enc.putInt64le (fromIntegral pc)
    frame = header <> payload

-- | Increment the program counter in the state and return the new value.
incPC :: SIO Int
incPC = do
    env <- get
    let next = pc env + 1
    put env{pc = next}
    return next

-- | 'forkIO' the action monad asynchronously and return future 'MVar'.
-- Provide the given state monad with its own program counter space.
--
-- Implemented as the single-result case of 'asyncList'.
async :: SIO a -> SIO (MVar a)
async action = do
    futures <- asyncList 1 (fmap (: []) action)
    return (head futures)

-- | Like 'async' for an action producing a list: returns @l@ futures, one
-- per expected element. Implemented as the single-row case of
-- 'asyncListList'.
asyncList :: Int -> SIO [a] -> SIO [MVar a]
asyncList l action = do
    rows <- asyncListList 1 l (fmap (: []) action)
    return (head rows)

-- | Run an action that produces a list of lists, returning an @l1@ by @l2@
-- grid of futures immediately.
--
-- The action runs on a copy of the current environment whose program
-- counter is the hash of the freshly incremented counter's textual form.
-- When it finishes, its results are written into the futures position by
-- position and the environment's fork barrier is decreased. The barrier
-- count is increased before the action starts. With the @noAsync@ option
-- set the action runs inline on the calling thread; otherwise it is forked
-- with 'forkIO'.
asyncListList :: Int -> Int -> SIO [[a]] -> SIO [[MVar a]]
asyncListList l1 l2 action = do
    freshPc <- incPC
    env <- get
    futures <- liftIO (replicateM l1 (replicateM l2 newEmptyMVar))

    let childEnv = env{pc = hash (show freshPc)}
        barrier = forkIOBarrier env
        -- Run the action, publish its results, then release the barrier.
        worker = do
          rows <- runSIO action childEnv
          zipWithM_ (zipWithM_ putMVar) futures rows
          decreaseBarrier barrier

    liftIO $ do
      -- Register the task with the barrier before it can finish.
      modifyMVar_ (count barrier) (\pending -> return (pending + 1))
      if noAsync (options env)
        then worker
        else void (forkIO worker)

    return futures

-- | Decrease the barrier's counter by one; when it reaches zero, put @()@
-- into the barrier's signal 'MVar'.
decreaseBarrier :: Barrier -> IO ()
decreaseBarrier (Barrier countVar signalVar) = modifyMVar_ countVar release
  where
    release pending = do
      let remaining = pending - 1
      when (remaining == 0) (putMVar signalVar ())
      return remaining