{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Client side of the Hadoop RPC wire protocol (version 9): connection
-- setup, request framing, and dispatch of responses to per-call callbacks.
module Network.Hadoop.Rpc
    ( Connection(..)
    , Protocol(..)
    , User
    , Method
    , RawRequest
    , RawResponse

    , initConnectionV9
    , invokeAsync
    , invoke
    ) where

import           Control.Applicative ((<$>))
import           Control.Concurrent (ThreadId, forkIO, newEmptyMVar, putMVar, takeMVar)
import           Control.Concurrent.STM
import           Control.Exception (SomeException(..), throwIO, handle)
import           Control.Monad (forever, when)
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import qualified Data.HashMap.Strict as H
import           Data.Hashable (Hashable)
import           Data.Maybe (fromMaybe, isNothing)
import           Data.Monoid ((<>))
import           Data.Monoid (mempty)
import           Data.Text (Text)
import qualified Data.Text as T
import qualified Data.UUID as UUID
import           System.Random (randomIO)

import           Data.ProtocolBuffers
import           Data.ProtocolBuffers.Orphans ()
import           Data.Serialize.Get
import           Data.Serialize.Put

import qualified Data.Hadoop.Protobuf.Headers as P
import           Data.Hadoop.Types
import qualified Network.Hadoop.Stream as S
import           Network.Socket (Socket)

------------------------------------------------------------------------

-- | An established RPC connection.
data Connection = Connection
    { cnVersion  :: !Int
      -- ^ RPC wire-protocol version of the connection
      -- (9 for connections made by 'initConnectionV9').
    , cnConfig   :: !HadoopConfig
      -- ^ Configuration the connection was created with.
    , cnProtocol :: !Protocol
      -- ^ Protocol announced to the server during the handshake.
    , invokeRaw  :: !(Method -> RawRequest -> (RawResponse -> IO ()) -> IO ())
      -- ^ Queue a raw request; the callback receives the raw response
      -- (or the error) once it is available.
    }

-- | Name and version of the remote protocol being called.
data Protocol = Protocol
    { prName    :: !Text
    , prVersion :: !Int
    } deriving (Eq, Ord, Show)

-- | Name of the remote method to invoke.
type Method = Text

-- | Already-encoded request body.
type RawRequest = ByteString

-- | Undecoded response body, or the error that prevented one.
type RawResponse = Either SomeException ByteString

type CallId = Int

newtype ClientId = ClientId { unClientId :: ByteString }

------------------------------------------------------------------------

-- hadoop-2.1.0-beta is on version 9
-- see https://issues.apache.org/jira/browse/HADOOP-8990 for differences

-- | Mutable state shared by the sender and receiver threads.
data ConnectionState = ConnectionState
    { csStream        :: !S.Stream
    , csClientId      :: !ClientId
    , csCallId        :: !(TVar CallId)
      -- ^ Call id to assign to the next outgoing request.
    , csRecvCallbacks :: !(TVar (H.HashMap CallId (RawResponse -> IO ())))
      -- ^ Callbacks of requests that were dequeued for sending and are
      -- awaiting a response, keyed by call id.
    , csSendQueue     :: !(TQueue (Method, RawRequest, RawResponse -> IO ()))
      -- ^ Requests waiting to be written to the stream.
    , csFatalError    :: !(TVar (Maybe SomeException))
      -- ^ Set once a sender/receiver thread has failed; later calls to
      -- 'invokeRaw' rethrow it instead of queueing.
    }

-- | Perform the version 9 handshake on the socket and start the sender and
-- receiver threads.
--
-- The returned 'Connection' reports 'cnVersion' 9, matching the version
-- byte written during the handshake.
initConnectionV9 :: HadoopConfig -> Protocol -> Socket -> IO Connection
initConnectionV9 config@HadoopConfig{..} protocol sock = do
    csStream   <- S.mkSocketStream sock
    csClientId <- mkClientId

    S.runPut csStream $ do
        putByteString "hrpc"
        putWord8 9  -- version
        putWord8 0  -- rpc service class (0 = default/protobuf, 1 = built-in, 2 = writable, 3 = protobuf)
        putWord8 0  -- auth protocol (0 = none, -33/0xDF = sasl)
        -- The connection context is sent with call id -3.
        putMessage $ delimitedBytesL (rpcRequestHeaderProto csClientId (-3))
                  <> delimitedBytesL (contextProto protocol hcUser)

    csCallId        <- newTVarIO 0
    csRecvCallbacks <- newTVarIO H.empty
    csSendQueue     <- newTQueueIO
    csFatalError    <- newTVarIO Nothing

    let cs = ConnectionState{..}

    _ <- forkSend cs
    _ <- forkRecv cs

    -- Must agree with the version byte written in the handshake above.
    return (Connection 9 config protocol (enqueue cs))
  where
    -- Queue a request for the sender thread, unless the connection has
    -- already failed, in which case the stored error is rethrown.
    enqueue :: ConnectionState -> Method -> RawRequest -> (RawResponse -> IO ()) -> IO ()
    enqueue ConnectionState{..} method bs k = do
        merr <- atomically $ do
            merr <- readTVar csFatalError
            when (isNothing merr) $
                writeTQueue csSendQueue (method, bs, k)
            return merr
        case merr of
            Just err -> throwIO err
            Nothing  -> return ()

    -- Sender thread: take a queued request, assign it the next call id,
    -- register its callback and write the framed message to the stream.
    forkSend :: ConnectionState -> IO ThreadId
    forkSend cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        bs <- atomically $ do
            (method, requestBytes, k) <- readTQueue csSendQueue

            callId <- readTVar csCallId
            modifyTVar' csCallId succ
            modifyTVar' csRecvCallbacks (H.insert callId k)

            return $ delimitedBytesL (rpcRequestHeaderProto csClientId callId)
                  <> delimitedBytesL (requestHeaderProto protocol method)
                  <> L.fromStrict requestBytes

        S.runPut csStream (putMessage bs)

    -- Receiver thread: read one length-framed response at a time and hand
    -- it to the callback registered under its call id.
    forkRecv :: ConnectionState -> IO ThreadId
    forkRecv cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        mget <- S.maybeGet csStream $ do
            n <- fromIntegral <$> getWord32be
            -- TODO Would be nice if we didn't have to isolate here
            -- TODO and could stream instead. We could stream if we
            -- TODO were able to read the varint length prefix
            -- TODO ourselves and keep track of how many bytes were
            -- TODO remaining instead of calling `getRemaining`.
            isolate n $ do
                hdr <- decodeLengthPrefixedMessage
                msg <- case getField (P.rspStatus hdr) of
                    P.Success -> Right <$> getRemaining
                    _         -> return . Left . SomeException $ rspError hdr
                return (hdr, msg)

        case mget of
            Nothing -> throwIO ConnectionClosed
            Just (hdr, msg) -> do
                let callId = fromIntegral $ getField $ P.rspCallId hdr
                -- A response whose call id has no registered callback is
                -- silently dropped.
                onResponse <- fromMaybe (const (return ()))
                          <$> lookupDelete csRecvCallbacks callId
                onResponse msg

    -- Record the failure and deliver it to every request that is still
    -- queued or awaiting a response.
    onSocketError :: ConnectionState -> SomeException -> IO ()
    onSocketError ConnectionState{..} ex = do
        ks <- atomically $ do
            writeTVar csFatalError (Just ex)
            sks <- map (\(_,_,k) -> k) <$> unfoldM (tryReadTQueue csSendQueue)
            rks <- H.elems <$> readTVar csRecvCallbacks
            -- Both the sender and the receiver install this handler, so
            -- forget the pending callbacks once collected; otherwise a
            -- second failure would invoke each of them again.
            writeTVar csRecvCallbacks H.empty
            return (sks ++ rks)
        -- A throwing callback must not prevent the others from running.
        mapM_ (\k -> handle ignore $ k $ Left ex) ks

    ignore :: SomeException -> IO ()
    ignore _ = return ()

-- | Generate a random client id: the bytes of a random UUID.
mkClientId :: IO ClientId
mkClientId = ClientId . L.toStrict . UUID.toByteString <$> randomIO

-- | Run the action repeatedly until it yields 'Nothing', returning the
-- 'Just' results in the order they were produced.
unfoldM :: Monad m => m (Maybe a) -> m [a]
unfoldM f = go []
  where
    -- Accumulate in reverse and flip once at the end, so that collecting
    -- n results is linear rather than quadratic.
    go acc = do
        m <- f
        case m of
            Nothing -> return (reverse acc)
            Just x  -> go (x : acc)

-- | Atomically remove a key from the map, returning the value it had (if any).
lookupDelete :: (Eq k, Hashable k) => TVar (H.HashMap k v) -> k -> IO (Maybe v)
lookupDelete var k = atomically $ do
    hm <- readTVar var
    writeTVar var (H.delete k hm)
    return (H.lookup k hm)

------------------------------------------------------------------------

-- | Connection context sent during the handshake: the protocol name and
-- the effective user (no real user is set).
contextProto :: Protocol -> User -> P.IpcConnectionContext
contextProto protocol user = P.IpcConnectionContext
    { P.ctxProtocol = putField (Just (prName protocol))
    , P.ctxUserInfo = putField (Just P.UserInformation
        { P.effectiveUser = putField (Just user)
        , P.realUser      = mempty
        })
    }

-- | RPC-level header that precedes every request on the wire.
rpcRequestHeaderProto :: ClientId -> CallId -> P.RpcRequestHeader
rpcRequestHeaderProto clientId callId = P.RpcRequestHeader
    { P.reqKind       = putField (Just P.ProtocolBuffer)
    , P.reqOp         = putField (Just P.FinalPacket)
    , P.reqCallId     = putField (fromIntegral callId)
    , P.reqClientId   = putField (unClientId clientId)
    , P.reqRetryCount = putField (Just (-1))
    }

-- | Per-call header naming the method and the protocol it belongs to.
requestHeaderProto :: Protocol -> Method -> P.RequestHeader
requestHeaderProto protocol method = P.RequestHeader
    { P.reqMethodName      = putField method
    , P.reqProtocolName    = putField (prName protocol)
    , P.reqProtocolVersion = putField (fromIntegral (prVersion protocol))
    }

-- | Build a 'RemoteError' from a response header, substituting
-- @"unknown error"@ for a missing exception class name or message.
rspError :: P.RpcResponseHeader -> RemoteError
rspError rsp = RemoteError (fromMaybe "unknown error" $ getField $ P.rspExceptionClassName rsp)
                           (fromMaybe "unknown error" $ getField $ P.rspErrorMsg rsp)

-- | Write a message preceded by its length as a big-endian 32-bit word.
putMessage :: L.ByteString -> Put
putMessage body = do
    putWord32be (fromIntegral (L.length body))
    putLazyByteString body

-- | Consume and return all of the remaining input.
getRemaining :: Get ByteString
getRemaining = do
    n <- remaining
    getByteString n

------------------------------------------------------------------------

-- | Invoke a remote method and block until its response arrives.
--
-- Throws the exception carried by the response callback if the call failed
-- or the response could not be decoded.
invoke :: (Decode b, Encode a) => Connection -> Text -> a -> IO b
invoke connection method arg = do
    mv <- newEmptyMVar
    invokeAsync connection method arg (putMVar mv)
    e <- takeMVar mv
    case e of
        Left ex -> throwIO ex
        Right x -> return x

-- | Invoke a remote method without waiting; the callback is given either the
-- decoded response or the error.
invokeAsync :: (Decode b, Encode a) => Connection -> Text -> a -> (Either SomeException b -> IO ()) -> IO ()
invokeAsync Connection{..} method arg k = invokeRaw method (delimitedBytes arg) k'
  where
    k' (Left err) = k (Left err)
    k' (Right bs) = k (fromDelimitedBytes bs)

-- | Encode a message with its length prefix, as a strict 'ByteString'.
delimitedBytes :: Encode a => a -> ByteString
delimitedBytes = runPut . encodeLengthPrefixedMessage

-- | Encode a message with its length prefix, as a lazy 'L.ByteString'.
delimitedBytesL :: Encode a => a -> L.ByteString
delimitedBytesL = L.fromStrict . delimitedBytes

-- | Decode a single length-prefixed message. Fails with a 'DecodeError' if
-- decoding fails or if any input is left over afterwards.
fromDelimitedBytes :: Decode a => ByteString -> Either SomeException a
fromDelimitedBytes bs = case runGetState decodeLengthPrefixedMessage bs 0 of
    Left err      -> decodeError (T.pack err)
    Right (x, "") -> Right x
    Right (_, _)  -> decodeError "decoded response but did not consume enough bytes"
  where
    decodeError = Left . SomeException . DecodeError