-- | A small request/response RPC layer on top of "Network.Endpoints".
--
-- Callers build a 'CallSite' and use 'call', 'gcall', 'anyCall' (or their
-- timeout variants) to send a 'Request' and collect matching 'Response's.
-- Handlers use 'hear' / 'hearAll' (or 'handle' / 'handleAll' for a
-- background loop) to receive requests and reply to them.
module Network.RPC (
Method,
newCallSite,
CallSite,
call,
callWithTimeout,
gcall,
gcallWithTimeout,
anyCall,
methodSelector,
hear,
hearTimeout,
hearAll,
hearAllTimeout,
Reply,
HandleSite,
handle,
handleAll,
hangup,
Request(..),
RequestId,
mkRequestId,
Response(..)
) where
import Network.Endpoints
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Monad
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Serialize
import Data.UUID
import Data.UUID.V4
import Data.Word
import GHC.Generics hiding (from)
-- | Name of a remote method; carried in 'requestMethod' and compared for
-- equality when a handler selects the requests it wants.
type Method = String
-- | Tag written at the front of every serialized RPC message so that a
-- 'Request' can be told apart from a 'Response' when decoding.
data RPCMessageType
  = Req -- ^ marks a serialized 'Request'
  | Rsp -- ^ marks a serialized 'Response'
  deriving (Eq, Show, Enum, Generic)

-- | Generic-derived encoding of the tag.
instance Serialize RPCMessageType
-- | Identifier that ties a 'Response' back to the 'Request' it answers.
-- Wraps four 32-bit words (see 'mkRequestId', which fills them from a UUID).
newtype RequestId = RequestId (Word32, Word32, Word32, Word32)
  deriving (Generic, Eq, Show)

-- | Generic-derived encoding of the four words.
instance Serialize RequestId
-- | Create a fresh 'RequestId' from the words of a newly generated
-- random UUID ('nextRandom').
mkRequestId :: IO RequestId
mkRequestId = fmap (RequestId . toWords) nextRandom
-- | A call travelling from a caller to a handler.
data Request = Request
  { requestId :: RequestId
    -- ^ identifier echoed back in the matching 'Response'
  , requestCaller :: Name
    -- ^ name the handler sends its reply to
  , requestMethod :: Method
    -- ^ method being invoked
  , requestArgs :: Message
    -- ^ already-serialized arguments for the method
  }
  deriving (Eq, Show)
-- | Wire format: the 'Req' tag, then id, caller, method and arguments, in
-- that order. Decoding does not succeed unless the leading tag is 'Req'.
instance Serialize Request where
  put req =
    put Req
      >> put (requestId req)
      >> put (requestCaller req)
      >> put (requestMethod req)
      >> put (requestArgs req)

  get = do
    -- Reject anything that is not tagged as a request.
    Req <- get
    -- Fields are read in the same order 'put' wrote them.
    liftM4 Request get get get get
-- | A handler's answer to a 'Request'.
data Response = Response
  { responseId :: RequestId
    -- ^ id of the 'Request' being answered
  , responseFrom :: Name
    -- ^ name of the handler that produced the answer
  , responseValue :: Message
    -- ^ already-serialized result
  }
  deriving (Eq, Show)
-- | Wire format: the 'Rsp' tag, then id, responder name and value, in that
-- order. Decoding does not succeed unless the leading tag is 'Rsp'.
instance Serialize Response where
  put rsp =
    put Rsp
      >> put (responseId rsp)
      >> put (responseFrom rsp)
      >> put (responseValue rsp)

  get = do
    -- Reject anything that is not tagged as a response.
    Rsp <- get
    -- Fields are read in the same order 'put' wrote them.
    liftM3 Response get get get
-- | Where calls are made from: the 'Endpoint' used for sending and
-- receiving, together with the 'Name' placed in 'requestCaller'.
data CallSite = CallSite Endpoint Name

-- | Build a 'CallSite' from an endpoint and the caller's own name.
newCallSite :: Endpoint -> Name -> CallSite
newCallSite endpoint callerName = CallSite endpoint callerName
-- | Send a request for @method@ with @args@ to @name@, then select from the
-- endpoint the 'Response' whose id matches the request and return its value.
-- Messages that do not decode as a 'Response', or that carry a different
-- id, are not selected.
call :: CallSite -> Name -> Method -> Message -> IO Message
call (CallSite endpoint from) name method args = do
  rid <- mkRequestId
  sendMessage endpoint name (encode (Request rid from method args))
  selectMessage endpoint (answerTo rid)
  where
    -- Accept only the response that carries our request id.
    answerTo rid msg =
      case decode msg of
        Right (Response respId _ value)
          | respId == rid -> Just value
        _ -> Nothing
-- | Like 'call', but races the call against @threadDelay delay@ (so @delay@
-- is in microseconds). Yields @Just@ the result if the call finishes first,
-- and @Nothing@ if the delay elapses first.
callWithTimeout :: CallSite -> Name -> Method -> Int-> Message -> IO (Maybe Message)
callWithTimeout site name method delay args =
  fmap (either Just (const Nothing)) (race pending timer)
  where
    pending = call site name method args
    timer = threadDelay delay
-- | Group call: send the same request to every name in @names@ and collect
-- one response per name, returning them keyed by responder.
--
-- Only responses that carry this request's id and come from one of @names@
-- are accepted. If a responder answers more than once before the call
-- completes, the later value replaces the earlier one.
--
-- When @names@ is empty there is nothing to wait for, so the result is an
-- empty map and no receive is attempted.
gcall :: CallSite -> [Name] -> Method -> Message -> IO (M.Map Name Message)
gcall (CallSite endpoint from) names method args = do
  rid <- mkRequestId
  let req = Request
        { requestId = rid
        , requestCaller = from
        , requestMethod = method
        , requestArgs = args
        }
  forM_ names $ \name -> sendMessage endpoint name (encode req)
  recvAll rid M.empty
  where
    -- Built once; used both for membership tests and the completion check.
    expected = S.fromList names

    -- Select the next response that belongs to this call.
    recv rid = selectMessage endpoint $ \msg ->
      case decode msg of
        Right (Response respId name value)
          | respId == rid && S.member name expected -> Just (name, value)
        _ -> Nothing

    -- Check for completion *before* receiving, so an empty (or already
    -- satisfied) set of expected responders never waits for a message.
    recvAll rid results
      | expected `S.isSubsetOf` M.keysSet results = return results
      | otherwise = do
          (replier, result) <- recv rid
          recvAll rid (M.insert replier result results)
-- | Group call with a deadline: send the same request to every name in
-- @names@ and collect responses until either all have answered or
-- @threadDelay delay@ elapses (so @delay@ is in microseconds).
--
-- The result has an entry for every name in @names@: @Just@ the response
-- if one was recorded, @Nothing@ otherwise.
--
-- When @names@ is empty the result is an empty map, returned without
-- waiting for the delay.
gcallWithTimeout :: CallSite -> [Name] -> Method -> Int -> Message -> IO (M.Map Name (Maybe Message))
gcallWithTimeout (CallSite endpoint from) names method delay args = do
  rid <- mkRequestId
  let req = Request
        { requestId = rid
        , requestCaller = from
        , requestMethod = method
        , requestArgs = args
        }
  forM_ names $ \name -> sendMessage endpoint name (encode req)
  -- Responses accumulate in a TVar so that whatever arrived before the
  -- deadline is still readable after the receiving side of the race ends.
  collected <- newTVarIO M.empty
  outcome <- race (recvAll rid collected) (threadDelay delay)
  case outcome of
    Left results -> return (complete results)
    Right _ -> fmap complete (readTVarIO collected)
  where
    -- Built once; used for membership tests, completion and the final map.
    expected = S.fromList names

    -- Select the next response that belongs to this call.
    recv rid = selectMessage endpoint $ \msg ->
      case decode msg of
        Right (Response respId name value)
          | respId == rid && S.member name expected -> Just (name, value)
        _ -> Nothing

    -- Check for completion *before* receiving, so an empty (or already
    -- satisfied) set of expected responders returns immediately.
    recvAll :: RequestId -> TVar (M.Map Name Message) -> IO (M.Map Name Message)
    recvAll rid results = do
      soFar <- readTVarIO results
      if expected `S.isSubsetOf` M.keysSet soFar
        then return soFar
        else do
          (replier, result) <- recv rid
          -- Strict update: do not pile up thunks inside the TVar.
          atomically $ modifyTVar' results (M.insert replier result)
          recvAll rid results

    -- One entry per expected name; names without a response map to Nothing.
    complete :: M.Map Name b -> M.Map Name (Maybe b)
    complete partial = M.fromSet (\name -> M.lookup name partial) expected
-- | Send the same request to every name in @names@ and return the first
-- selected response together with the name of the responder. Only
-- responses carrying this request's id and coming from one of @names@ are
-- accepted.
--
-- NOTE(review): with an empty @names@ no response can ever be accepted;
-- presumably the call then never returns. Confirm against 'selectMessage'.
anyCall :: CallSite -> [Name] -> Method -> Message -> IO (Message,Name)
anyCall (CallSite endpoint from) names method args = do
  rid <- mkRequestId
  let payload = encode (Request rid from method args)
  forM_ names $ \target -> sendMessage endpoint target payload
  selectMessage endpoint (firstAnswer rid)
  where
    -- Accept a response for this call from any of the addressed names.
    firstAnswer rid msg =
      case decode msg of
        Right (Response respId responder value)
          | respId == rid && responder `elem` names -> Just (value, responder)
        _ -> Nothing
-- | Action handed to a handler for answering a request: apply it to the
-- result value to send that value back.
type Reply b = b -> IO ()
-- | Message selector for requests to one particular method. Yields the
-- caller, request id and arguments when @msg@ decodes as a 'Request' whose
-- method equals @method@; yields @Nothing@ otherwise.
methodSelector :: Method -> Message -> Maybe (Name,RequestId,Message)
methodSelector method msg =
  case decode msg of
    Right (Request rid caller rmethod args)
      | rmethod == method -> Just (caller, rid, args)
    _ -> Nothing
-- | Message selector for requests to any method. Yields the caller,
-- request id, method and arguments when @msg@ decodes as a 'Request';
-- yields @Nothing@ otherwise.
anySelector :: Message -> Maybe (Name,RequestId,Method,Message)
anySelector msg = either (const Nothing) (Just . fields) (decode msg)
  where
    fields (Request rid caller method args) = (caller, rid, method, args)
-- | Select from @endpoint@ the next request for @method@ and return its
-- arguments together with a 'Reply' action. Running the reply sends a
-- 'Response' (tagged with the request's id and with @name@ as responder)
-- back to the caller.
hear :: Endpoint -> Name -> Method -> IO (Message,Reply Message)
hear endpoint name method = do
  (caller, rid, args) <- selectMessage endpoint (methodSelector method)
  let respond result =
        sendMessage endpoint caller (encode (Response rid name result))
  return (args, respond)
-- | Like 'hear', but uses 'selectMessageTimeout' with the given @timeout@
-- (presumably microseconds; confirm against "Network.Endpoints"). Yields
-- @Nothing@ when no matching request was selected, otherwise @Just@ the
-- arguments and a 'Reply' action.
hearTimeout :: Endpoint -> Name -> Method -> Int -> IO (Maybe (Message,Reply Message))
hearTimeout endpoint name method timeout =
  fmap (fmap withReply) (selectMessageTimeout endpoint timeout (methodSelector method))
  where
    -- Pair the arguments with an action that answers this specific request.
    withReply (caller, rid, args) = (args, respond caller rid)
    respond caller rid result =
      sendMessage endpoint caller (encode (Response rid name result))
-- | Select from @endpoint@ the next request for any method and return the
-- method, its arguments and a 'Reply' action. Running the reply sends a
-- 'Response' (tagged with the request's id and with @name@ as responder)
-- back to the caller.
hearAll :: Endpoint -> Name -> IO (Method,Message,Reply Message)
hearAll endpoint name = do
  -- Uses the module-level 'anySelector' rather than a shadowing local
  -- copy, so this stays in step with 'hearAllTimeout'.
  (caller,rid,method,args) <- selectMessage endpoint anySelector
  return (method,args,reply caller rid)
  where
    reply caller rid result =
      sendMessage endpoint caller $ encode $ Response rid name result
-- | Like 'hearAll', but uses 'selectMessageTimeout' with the given
-- @timeout@ (presumably microseconds; confirm against
-- "Network.Endpoints"). Yields @Nothing@ when no request was selected,
-- otherwise @Just@ the method, arguments and a 'Reply' action.
hearAllTimeout :: Endpoint -> Name -> Int -> IO (Maybe (Method,Message,Reply Message))
hearAllTimeout endpoint name timeout =
  fmap (fmap withReply) (selectMessageTimeout endpoint timeout anySelector)
  where
    -- Pair the request details with an action that answers this request.
    withReply (caller, rid, method, args) = (method, args, respond caller rid)
    respond caller rid result =
      sendMessage endpoint caller (encode (Response rid name result))
-- | A running handler: the name it answers as, plus the background task
-- (started by 'handle' or 'handleAll') that 'hangup' cancels.
data HandleSite = HandleSite Name (Async ())
-- | Start a background task that repeatedly hears a request for @method@,
-- applies @fn@ to its arguments and replies with the result. The returned
-- 'HandleSite' can be passed to 'hangup' to stop the task.
--
-- NOTE(review): nothing here catches exceptions raised by @fn@; presumably
-- one would end the loop. Confirm whether that is intended.
handle :: Endpoint -> Name -> Method -> (Message -> IO Message) -> IO HandleSite
handle endpoint name method fn = fmap (HandleSite name) (async serve)
  where
    -- Serve one request per iteration, indefinitely.
    serve = forever $ do
      (args, reply) <- hear endpoint name method
      fn args >>= reply
-- | Start a background task that repeatedly hears a request for any
-- method, applies @fn@ to the method name and arguments, and replies with
-- the result. The returned 'HandleSite' can be passed to 'hangup' to stop
-- the task.
--
-- NOTE(review): nothing here catches exceptions raised by @fn@; presumably
-- one would end the loop. Confirm whether that is intended.
handleAll :: Endpoint -> Name -> (Method -> Message -> IO Message) -> IO HandleSite
handleAll endpoint name fn = fmap (HandleSite name) (async serve)
  where
    -- Serve one request per iteration, indefinitely.
    serve = forever $ do
      (method, args, reply) <- hearAll endpoint name
      fn method args >>= reply
-- | Stop a handler started by 'handle' or 'handleAll' by cancelling its
-- background task.
hangup :: HandleSite -> IO ()
hangup (HandleSite _ task) = cancel task