{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Client side of the Hadoop RPC wire protocol (version 7 handshake).
--
-- A 'Connection' is created from an already connected 'Socket' with
-- 'initConnectionV7'. Requests are issued with 'invoke' (blocking) or
-- 'invokeAsync' (callback based).
module Network.Hadoop.Rpc
    ( Connection(..)
    , Protocol(..)
    , User
    , Method
    , RawRequest
    , RawResponse

    , initConnectionV7
    , invokeAsync
    , invoke
    ) where

import           Control.Applicative ((<$>), (<*>))
import           Control.Concurrent (ThreadId, forkIO, newEmptyMVar, putMVar, takeMVar)
import           Control.Concurrent.STM
import           Control.Exception (SomeException(..), throwIO, handle)
import           Control.Monad (forever, when)
import           Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.HashMap.Strict as H
import           Data.Hashable (Hashable)
import           Data.Maybe (fromMaybe, isNothing)
import           Data.Monoid (mempty)
import           Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import           Data.Text.Encoding.Error (lenientDecode)

import           Data.ProtocolBuffers
import           Data.ProtocolBuffers.Orphans ()
import           Data.Serialize.Get
import           Data.Serialize.Put

import           Data.Hadoop.Protobuf.Headers
import           Data.Hadoop.Types
import qualified Network.Hadoop.Stream as S
import           Network.Socket (Socket)

------------------------------------------------------------------------

-- | Handle to an established RPC connection.
data Connection = Connection
    { cnVersion  :: !Int
      -- ^ RPC version used for the handshake.
    , cnConfig   :: !HadoopConfig
      -- ^ Configuration the connection was created with.
    , cnProtocol :: !Protocol
      -- ^ Protocol sent with every request.
    , invokeRaw  :: !(Method -> RawRequest -> (RawResponse -> IO ()) -> IO ())
      -- ^ Queue a raw request; the callback receives the raw response or
      -- the error that prevented one from being obtained.
    }

-- | Name and version of the remote protocol being spoken.
data Protocol = Protocol
    { prName    :: !Text
    , prVersion :: !Int
    } deriving (Eq, Ord, Show)

-- | Name of the remote method to invoke.
type Method      = Text

-- | Serialized request payload.
type RawRequest  = ByteString

-- | Serialized response payload, or the failure that occurred.
type RawResponse = Either SomeException ByteString

-- | Identifier used to match a response to its request.
type CallId      = Int

------------------------------------------------------------------------

-- hadoop-2.1.0-beta is on version 9
-- see https://issues.apache.org/jira/browse/HADOOP-8990 for differences

-- | Mutable state shared between the sender and receiver threads.
data ConnectionState = ConnectionState
    { csStream        :: !S.Stream
      -- ^ Stream wrapping the socket.
    , csCallId        :: !(TVar CallId)
      -- ^ Next call id to hand out.
    , csRecvCallbacks :: !(TVar (H.HashMap CallId (RawResponse -> IO ())))
      -- ^ Callbacks of requests that have been sent and await a response.
    , csSendQueue     :: !(TQueue (Method, RawRequest, RawResponse -> IO ()))
      -- ^ Requests waiting to be written by the sender thread.
    , csFatalError    :: !(TVar (Maybe SomeException))
      -- ^ Set once either worker thread has failed.
    }

-- | Perform the version 7 handshake on the socket and start the sender and
-- receiver threads.
--
-- The handshake writes the @hrpc@ magic, the version, auth method and
-- serialization bytes, followed by the length-prefixed connection context.
initConnectionV7 :: HadoopConfig -> Protocol -> Socket -> IO Connection
initConnectionV7 config@HadoopConfig{..} protocol sock = do
    csStream <- S.mkSocketStream sock

    S.runPut csStream $ do
        putByteString "hrpc"
        putWord8 7  -- version
        putWord8 80 -- auth method (80 = simple, 81 = kerberos/gssapi, 82 = token/digest-md5)
        putWord8 0  -- ipc serialization type (0 = protobuf)

        let bs = runPut (encodeMessage context)
        putWord32be (fromIntegral (B.length bs))
        putByteString bs

    csCallId        <- newTVarIO 0
    csRecvCallbacks <- newTVarIO H.empty
    csSendQueue     <- newTQueueIO
    csFatalError    <- newTVarIO Nothing

    let cs = ConnectionState{..}

    _ <- forkSend cs
    _ <- forkRecv cs

    return (Connection 7 config protocol (enqueue cs))
  where
    -- Queue a request for the sender thread. If a worker thread has already
    -- recorded a fatal error, that error is rethrown and nothing is queued.
    enqueue :: ConnectionState -> Method -> RawRequest -> (RawResponse -> IO ()) -> IO ()
    enqueue ConnectionState{..} method bs k = do
        merr <- atomically $ do
            merr <- readTVar csFatalError
            when (isNothing merr) $
                writeTQueue csSendQueue (method, bs, k)
            return merr
        maybe (return ()) throwIO merr

    -- Sender thread: take the next queued request, allocate a call id,
    -- register its callback, then write the length-prefixed frame.
    forkSend :: ConnectionState -> IO ThreadId
    forkSend cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        bs <- atomically $ do
            (method, requestBytes, k) <- readTQueue csSendQueue

            callId <- readTVar csCallId
            modifyTVar' csCallId succ
            -- Registering in the same transaction that dequeues means the
            -- callback is always reachable from either the queue or the map.
            modifyTVar' csRecvCallbacks (H.insert callId k)

            return $ runPut $ do
                encodeLengthPrefixedMessage (requestHeaderProto callId)
                encodeLengthPrefixedMessage (requestProto method requestBytes)

        S.runPut csStream $ do
            putWord32be (fromIntegral (B.length bs))
            putByteString bs

    -- Receiver thread: read a response header, find (and unregister) the
    -- matching callback, then read the body or the error and deliver it.
    forkRecv :: ConnectionState -> IO ThreadId
    forkRecv cs@ConnectionState{..} = forkIO $ handle (onSocketError cs) $ forever $ do
        hdr <- S.maybeGet csStream decodeLengthPrefixedMessage
        case hdr of
          Nothing     -> throwIO ConnectionClosed
          Just rspHdr -> do
            let callId = fromIntegral $ getField $ rspCallId rspHdr
            -- Responses without a registered callback are still read off
            -- the stream, but their payload is discarded.
            onResponse <- fromMaybe (const (return ()))
                      <$> lookupDelete csRecvCallbacks callId
            case getField (rspStatus rspHdr) of
              Success -> S.runGet csStream getResponse >>= onResponse . Right
              _       -> S.runGet csStream getError    >>= onResponse . Left . SomeException

    -- Failure handler shared by both worker threads: record the error, then
    -- fail every queued and every in-flight request with it.
    onSocketError :: ConnectionState -> SomeException -> IO ()
    onSocketError ConnectionState{..} ex = do
        ks <- atomically $ do
            writeTVar csFatalError (Just ex)
            sks <- map (\(_, _, k) -> k) <$> unfoldM (tryReadTQueue csSendQueue)
            rks <- H.elems <$> readTVar csRecvCallbacks
            -- Empty the map so that a later failure of the other worker
            -- thread, or a late response, cannot run these callbacks again.
            writeTVar csRecvCallbacks H.empty
            return (sks ++ rks)
        -- Best effort: a callback that throws must not stop the remaining
        -- callbacks from being notified.
        mapM_ (\k -> handle ignore $ k $ Left ex) ks

    ignore :: SomeException -> IO ()
    ignore _ = return ()

    -- Connection context sent once, as part of the handshake.
    context = IpcConnectionContext
        { ctxProtocol = putField (Just (prName protocol))
        , ctxUserInfo = putField (Just UserInformation
            { effectiveUser = putField (Just hcUser)
            , realUser      = mempty
            })
        }

    -- Header that precedes every request.
    requestHeaderProto callId = RpcRequestHeader
        { reqKind   = putField (Just ProtocolBuffer)
        , reqOp     = putField (Just FinalPacket)
        , reqCallId = putField (fromIntegral callId)
        }

    -- Request body wrapping the already serialized method arguments.
    requestProto method bytes = RpcRequest
        { reqMethodName      = putField method
        , reqBytes           = putField (Just bytes)
        , reqProtocolName    = putField (prName protocol)
        , reqProtocolVersion = putField (fromIntegral (prVersion protocol))
        }

-- | Run the action repeatedly until it yields 'Nothing', collecting the
-- 'Just' results in the order they were produced.
unfoldM :: Monad m => m (Maybe a) -> m [a]
unfoldM f = go []
  where
    -- Accumulate in reverse and flip once at the end, which keeps the whole
    -- traversal linear instead of appending to the tail on every step.
    go acc = do
        m <- f
        case m of
          Nothing -> return (reverse acc)
          Just x  -> go (x : acc)

-- | Atomically remove a key from the map, returning the value it had.
lookupDelete :: (Eq k, Hashable k) => TVar (H.HashMap k v) -> k -> IO (Maybe v)
lookupDelete var k = atomically $ do
    found <- H.lookup k <$> readTVar var
    modifyTVar' var (H.delete k)
    return found

------------------------------------------------------------------------

-- | Read a response body: a 32-bit big-endian length followed by that many
-- bytes.
getResponse :: Get ByteString
getResponse = do
    n <- fromIntegral <$> getWord32be
    getByteString n

-- | Read an error response: two length-prefixed UTF-8 strings, passed in
-- order to 'RemoteError'.
getError :: Get RemoteError
getError = RemoteError <$> getText <*> getText
  where
    getText = do
        n <- fromIntegral <$> getWord32be
        -- Decode leniently: malformed UTF-8 in a server supplied message
        -- becomes U+FFFD instead of an exception thrown from pure code.
        T.decodeUtf8With lenientDecode <$> getByteString n

------------------------------------------------------------------------

-- | Invoke a remote method and block until its response arrives.
--
-- Throws the exception carried by the response if the call failed or the
-- response could not be decoded.
invoke :: (Decode b, Encode a) => Connection -> Text -> a -> IO b
invoke connection method arg = do
    mv <- newEmptyMVar
    invokeAsync connection method arg (putMVar mv)
    takeMVar mv >>= either throwIO return

-- | Invoke a remote method without blocking. The callback receives either
-- the decoded response or the error that occurred.
invokeAsync :: (Decode b, Encode a) => Connection -> Text -> a -> (Either SomeException b -> IO ()) -> IO ()
invokeAsync Connection{..} method arg k = invokeRaw method (encodeBytes arg) k'
  where
    -- Errors pass through untouched; successful payloads are decoded.
    k' raw = k (either Left decodeBytes raw)

-- | Serialize a message to a strict 'ByteString'.
encodeBytes :: Encode a => a -> ByteString
encodeBytes = runPut . encodeMessage

-- | Deserialize a message, requiring that the whole input is consumed.
decodeBytes :: Decode a => ByteString -> Either SomeException a
decodeBytes bs = case runGetState decodeMessage bs 0 of
    Left err -> decodeError (T.pack err)
    Right (x, rest)
        | B.null rest -> Right x
        | otherwise   -> decodeError "decoded response but did not consume enough bytes"
  where
    decodeError = Left . SomeException . DecodeError