module Network.RPC (
Method,
newCallSite,
CallSite,
call,
callWithTimeout,
gcall,
gcallWithTimeout,
hear,
hearTimeout,
Reply,
HandleSite,
handle,
hangup
) where
import Network.Endpoints
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Monad
import qualified Data.Map as M
import qualified Data.Set as S
import Data.Serialize
import Data.UUID
import Data.UUID.V4
import Data.Word
-- | Name of a remote procedure. Handlers ('hear', 'hearTimeout', 'handle')
-- accept a request only when its method compares equal to this string.
type Method = String

-- | Correlates a 'Response' with the 'Request' that caused it. Callers in
-- this module derive it from a fresh random UUID via 'toWords'.
type RequestId = (Word32, Word32, Word32, Word32)

-- | An RPC request as it is put on the wire.
--
-- The constructor deliberately carries no @Serialize a@ context: every
-- function that builds, encodes or decodes a request already requires that
-- constraint, so a constructor context was redundant and only forced the
-- ExistentialQuantification/GADTs extension on the module.
data Request a = Request
  { requestId     :: RequestId  -- ^ id echoed back in the matching 'Response'
  , requestCaller :: Name       -- ^ name the reply is sent back to
  , requestMethod :: Method     -- ^ procedure being invoked
  , requestArgs   :: a          -- ^ argument payload
  }
-- | Wire format: id, caller, method, then arguments, each in its own
-- 'Serialize' encoding and in exactly that order.
instance (Serialize a) => Serialize (Request a) where
  put (Request rid caller method payload) =
    put rid >> put caller >> put method >> put payload
  get = Request <$> get <*> get <*> get <*> get
-- | The reply to a 'Request', tagged with the originating request id and
-- the name of the responder.
--
-- As with 'Request', there is no @Serialize b@ constructor context: all
-- producers and consumers already carry that constraint themselves.
data Response b = Response
  { responseId    :: RequestId  -- ^ copied from the answered request's 'requestId'
  , responseFrom  :: Name       -- ^ name of the endpoint that answered
  , responseValue :: b          -- ^ result payload
  }
-- | Wire format: request id, responder name, then the result value, in
-- exactly that order.
instance (Serialize b) => Serialize (Response b) where
  put (Response rid responder result) =
    put rid >> put responder >> put result
  get = Response <$> get <*> get <*> get
-- | Everything needed to issue calls: the local 'Endpoint' used to send
-- requests and receive replies, and the 'Name' placed in each request as
-- the caller (replies are addressed to it).
data CallSite = CallSite Endpoint Name

-- | Build a 'CallSite' from an endpoint and the caller's own name.
newCallSite :: Endpoint -> Name -> CallSite
newCallSite endpoint self = CallSite endpoint self
-- | Invoke @method@ on the endpoint @name@ with @args@ and block until the
-- matching reply arrives.
--
-- A fresh random request id is generated per call; only a message that
-- decodes as a 'Response' carrying that id is accepted as the answer.
-- There is no timeout here (see 'callWithTimeout').
call :: (Serialize a, Serialize b) => CallSite -> Name -> Method -> a -> IO b
call (CallSite endpoint from) name method args = do
  rid <- toWords <$> nextRandom
  sendMessage_ endpoint name $ encode $ Request rid from method args
  selectMessage endpoint (replyTo rid)
  where
    -- Selector: accept only a decodable response whose id matches ours.
    replyTo rid msg = case decode msg of
      Right (Response answered _ value) | answered == rid -> Just value
      _ -> Nothing
-- | Like 'call', but returns 'Nothing' if no reply has arrived before
-- @delay@ elapses (the value is handed to 'threadDelay', i.e. microseconds).
-- Whichever of the call and the timer finishes second is cancelled by 'race'.
callWithTimeout :: (Serialize a, Serialize b) => CallSite -> Name -> Method -> Int-> a -> IO (Maybe b)
callWithTimeout site name method delay args =
  either Just (const Nothing)
    <$> race (call site name method args) (threadDelay delay)
-- | Group call: send the same request to every endpoint in @names@ and block
-- until each of them has replied, returning the replies keyed by responder.
--
-- A reply is accepted only if it carries this call's request id and comes
-- from a member of @names@; a second reply from the same member overwrites
-- the first. An empty @names@ list returns an empty map immediately (it
-- previously blocked forever waiting for a reply that could never come).
-- There is no timeout: a member that never answers blocks this call
-- indefinitely (see 'gcallWithTimeout').
gcall :: (Serialize a, Serialize b) => CallSite -> [Name] -> Method -> a -> IO (M.Map Name b)
gcall (CallSite endpoint from) names method args = do
  ruuid <- nextRandom
  let req = Request { requestId = toWords ruuid
                    , requestCaller = from
                    , requestMethod = method
                    , requestArgs = args
                    }
  forM_ names $ \name -> sendMessage_ endpoint name $ encode req
  recvAll req M.empty
  where
    expected = S.fromList names
    -- Wait for one reply to @req@ from some member of the group.
    recv req = selectMessage endpoint $ \msg ->
      case decode msg of
        Right (Response rid replier value)
          | rid == requestId req && S.member replier expected -> Just (replier, value)
        _ -> Nothing
    -- Check for completion *before* waiting, so we never block on a reply
    -- when every expected member (possibly none) has already answered.
    recvAll req results
      | all (`M.member` results) names = return results
      | otherwise = do
          (replier, result) <- recv req
          recvAll req (M.insert replier result results)
-- | Group call with a deadline: send the request to every endpoint in
-- @names@ and collect replies until all have answered or @delay@ (handed to
-- 'threadDelay', i.e. microseconds) elapses.
--
-- The result has one entry per name in @names@: @Just reply@ for members
-- that answered in time, 'Nothing' for the rest. An empty @names@ list now
-- returns immediately instead of waiting out the full delay.
gcallWithTimeout :: (Serialize a, Serialize b) => CallSite -> [Name] -> Method -> Int -> a -> IO (M.Map Name (Maybe b))
gcallWithTimeout (CallSite endpoint from) names method delay args = do
  ruuid <- nextRandom
  let req = Request { requestId = toWords ruuid
                    , requestCaller = from
                    , requestMethod = method
                    , requestArgs = args
                    }
  forM_ names $ \name -> sendMessage_ endpoint name $ encode req
  -- Replies are accumulated in a TVar so that, if the timer wins the race
  -- and the receiver is cancelled, the partial results survive.
  allResults <- newTVarIO M.empty
  outcome <- race (recvAll req allResults) (threadDelay delay)
  results <- case outcome of
    Left finished -> return finished
    Right _ -> readTVarIO allResults
  return $ complete results
  where
    expected = S.fromList names
    -- Wait for one reply to @req@ from some member of the group.
    recv req = selectMessage endpoint $ \msg ->
      case decode msg of
        Right (Response rid replier value)
          | rid == requestId req && S.member replier expected -> Just (replier, value)
        _ -> Nothing
    -- Check for completion before waiting, so an empty (or already complete)
    -- group never blocks. NOTE(review): if the timer fires after 'recv' has
    -- returned but before the insert commits, that one reply is dropped.
    recvAll req allResults = do
      results <- readTVarIO allResults
      if all (`M.member` results) names
        then return results
        else do
          (replier, result) <- recv req
          atomically $ modifyTVar' allResults (M.insert replier result)
          recvAll req allResults
    -- One entry per requested name; missing replies become 'Nothing'.
    complete partial = M.fromList [(name, M.lookup name partial) | name <- names]
-- | Continuation that sends a result back to whoever made a request.
type Reply b = b -> IO ()

-- | Block until a request for @method@ arrives on @endpoint@, returning its
-- arguments together with a 'Reply' that answers it. The reply is sent to
-- the request's caller, tagged with the request's id and with @name@ as the
-- responder.
hear :: (Serialize a,Serialize b) => Endpoint -> Name -> Method -> IO (a,Reply b)
hear endpoint name method = do
  (caller, rid, payload) <- selectMessage endpoint forMethod
  return (payload, respond caller rid)
  where
    -- Selector: accept only a decodable request addressed to @method@.
    forMethod msg = case decode msg of
      Right (Request rid caller wanted payload)
        | wanted == method -> Just (caller, rid, payload)
      _ -> Nothing
    respond caller rid result =
      sendMessage_ endpoint caller $ encode $ Response rid name result
-- | Like 'hear', but gives up and returns 'Nothing' if no request for
-- @method@ is selected within @timeout@ (passed unchanged to
-- 'selectMessageTimeout'; units are whatever that function uses).
hearTimeout :: (Serialize a,Serialize b) => Endpoint -> Name -> Method -> Int -> IO (Maybe (a,Reply b))
hearTimeout endpoint name method timeout = do
  found <- selectMessageTimeout endpoint timeout forMethod
  case found of
    Nothing -> return Nothing
    Just (caller, rid, payload) -> return (Just (payload, respond caller rid))
  where
    -- Selector: accept only a decodable request addressed to @method@.
    forMethod msg = case decode msg of
      Right (Request rid caller wanted payload)
        | wanted == method -> Just (caller, rid, payload)
      _ -> Nothing
    respond caller rid result =
      sendMessage_ endpoint caller $ encode $ Response rid name result
-- | A running request handler: the name it answers as, plus the background
-- thread serving requests (stopped with 'hangup').
data HandleSite = HandleSite Name (Async ())

-- | Start a background thread that repeatedly waits for a @method@ request,
-- runs @fn@ on its arguments, and replies with the result.
--
-- Exceptions thrown by @fn@ are not caught here, so they end the serving
-- loop; the 'Async' then holds the exception.
handle :: (Serialize a, Serialize b) => Endpoint -> Name -> Method -> (a -> IO b) -> IO HandleSite
handle endpoint name method fn =
  HandleSite name <$> async (forever serveOne)
  where
    serveOne = do
      (payload, reply) <- hear endpoint name method
      fn payload >>= reply
-- | Stop a handler started by 'handle' by cancelling its serving thread.
hangup :: HandleSite -> IO ()
hangup (HandleSite _ worker) = cancel worker