module Network.AMQP.Internal where
import Control.Concurrent
import Control.Monad
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put as BPut
import Data.Maybe
import Data.Text (Text)
import Data.Typeable
import Network
import System.IO
import qualified Control.Exception as CE
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.Foldable as F
import qualified Data.IntMap as IM
import qualified Data.Sequence as Seq
import qualified Data.Text as T
import qualified Data.Text.Encoding as E
import Network.AMQP.Protocol
import Network.AMQP.Types
import Network.AMQP.Helpers
import Network.AMQP.Generated
-- | Delivery mode of a message. 'deliveryModeToInt' encodes 'NonPersistent'
-- as octet 1 and 'Persistent' as octet 2.
data DeliveryMode
    = Persistent
    | NonPersistent
    deriving (Eq, Ord, Read, Show)
-- | Encode a 'DeliveryMode' as its octet value: 1 for 'NonPersistent',
-- 2 for 'Persistent'.
deliveryModeToInt :: DeliveryMode -> Octet
deliveryModeToInt mode = case mode of
    Persistent    -> 2
    NonPersistent -> 1
-- | Decode an octet into a 'DeliveryMode' (inverse of 'deliveryModeToInt').
--
-- Partial: any octet other than 1 or 2 raises an 'error'.
intToDeliveryMode :: Octet -> DeliveryMode
intToDeliveryMode code = case code of
    1 -> NonPersistent
    2 -> Persistent
    _ -> error ("Unknown delivery mode int: " ++ show code)
-- | A message body together with the subset of content-header properties
-- that 'msgFromContentHeaderProperties' extracts. Every property is optional.
data Message = Message
    { msgBody          :: BL.ByteString       -- ^ the (lazy) content body
    , msgDeliveryMode  :: Maybe DeliveryMode  -- ^ decoded from the delivery-mode octet
    , msgTimestamp     :: Maybe Timestamp     -- ^ the timestamp property
    , msgID            :: Maybe Text          -- ^ the message-id property
    , msgContentType   :: Maybe Text          -- ^ the content-type property
    , msgReplyTo       :: Maybe Text          -- ^ the reply-to property
    , msgCorrelationID :: Maybe Text          -- ^ the correlation-id property
    , msgHeaders       :: Maybe FieldTable    -- ^ the headers table
    }
    deriving (Eq, Ord, Read, Show)
-- | Delivery metadata handed to a consumer callback alongside the 'Message'.
-- The fields are filled from the corresponding @Basic_deliver@ arguments.
data Envelope = Envelope
    { envDeliveryTag   :: LongLongInt  -- ^ delivery tag of the @Basic_deliver@
    , envRedelivered   :: Bool         -- ^ redelivered flag of the @Basic_deliver@
    , envExchangeName  :: Text         -- ^ exchange named in the delivery
    , envRoutingKey    :: Text         -- ^ routing key named in the delivery
    , envChannel       :: Channel      -- ^ channel the delivery arrived on
    }
-- | Description of a returned (@Basic_return@) message, as passed to return
-- listeners registered with 'addReturnListener'.
data PublishError = PublishError
    { errReplyCode  :: ReturnReplyCode  -- ^ decoded reply code plus reply text
    , errExchange   :: Maybe Text       -- ^ exchange named in the return
    , errRoutingKey :: Text             -- ^ routing key named in the return
    }
    deriving (Eq, Read, Show)
-- | Reason carried by a @Basic_return@; each constructor holds the reply text.
-- The channel receiver maps reply code 312 to 'Unroutable', 313 to
-- 'NoConsumers' and 404 to 'NotFound'.
data ReturnReplyCode
    = Unroutable Text
    | NoConsumers Text
    | NotFound Text
    deriving (Eq, Read, Show)
-- | A complete logical unit exchanged on a channel: either a bare method, or
-- a method together with its content header properties and the concatenated
-- content body.
data Assembly
    = SimpleMethod MethodPayload
    | ContentMethod MethodPayload ContentHeaderProperties BL.ByteString
    deriving Show
-- | Read the next complete 'Assembly' from a channel's frame queue.
--
-- The first frame must be a method frame. When 'hasContent' says the method
-- carries content, the header and body frames are gathered with
-- 'collectContent'. Any other leading frame raises an 'error'.
readAssembly :: Chan FramePayload -> IO Assembly
readAssembly chan = readChan chan >>= assemble
  where
    assemble frame@(MethodPayload method)
        | hasContent frame = do
            (props, body) <- collectContent chan
            return (ContentMethod method props body)
        | otherwise = return (SimpleMethod method)
    assemble other = error $ "didn't expect frame: " ++ show other
-- | Read one content header frame followed by as many content body frames as
-- are needed to cover the body size announced in the header.
--
-- Returns the header properties and the concatenated body. The pattern binds
-- below fail (raising an exception in 'IO') if a frame of an unexpected kind
-- is read from the queue.
collectContent :: Chan FramePayload -> IO (ContentHeaderProperties, BL.ByteString)
collectContent chan = do
    (ContentHeaderPayload _ _ bodySize props) <- readChan chan
    content <- collect $ fromIntegral bodySize
    return (props, BL.concat content)
  where
    -- @remaining@ counts the body bytes still expected; stop once the
    -- announced size has been covered.
    collect remaining | remaining <= 0 = return []
    collect remaining = do
        (ContentBodyPayload payload) <- readChan chan
        -- Subtract what this frame delivered (the operator was missing, which
        -- turned the expression into an application of the counter).
        rest <- collect (remaining - BL.length payload)
        return $ payload : rest
-- | State shared by everything that uses one broker connection.
data Connection = Connection
    { connHandle         :: Handle
      -- ^ handle frames are read from and written to
    , connChannels       :: MVar (IM.IntMap (Channel, ThreadId))
      -- ^ open channels and their receiver threads, keyed by channel id
    , connMaxFrameSize   :: Int
      -- ^ frame size settled on during the handshake
    , connClosed         :: MVar (Maybe String)
      -- ^ 'Nothing' while open; @Just reason@ once the connection is closed
    , connClosedLock     :: MVar ()
      -- ^ created empty and filled when the receiver thread has finished;
      -- 'closeConnection' blocks on it
    , connWriteLock      :: MVar ()
      -- ^ taken around frame writes made through 'writeFrames' and 'closeConnection'
    , connClosedHandlers :: MVar [IO ()]
      -- ^ actions run after the connection has been closed
    , lastChannelID      :: MVar Int
      -- ^ most recently allocated channel id
    }
-- | Options consumed by @openConnection''@.
data ConnectionOpts = ConnectionOpts
    { coServers        :: ![(String, PortNumber)]
      -- ^ brokers to try, in order, until one accepts the connection
    , coVHost          :: !Text
      -- ^ virtual host sent in @Connection_open@
    , coAuth           :: ![SASLMechanism]
      -- ^ client SASL mechanisms; the first one the server also lists is used
    , coMaxFrameSize   :: !(Maybe Word32)
      -- ^ client-side frame size preference, combined with the server's
      -- @frame_max@ via @chooseMin@
    , coHeartbeatDelay :: !(Maybe Word16)
      -- ^ NOTE(review): not read by the code in this module - confirm usage
    , coMaxChannel     :: !(Maybe Word16)
      -- ^ NOTE(review): not read by the code in this module - confirm usage
    }
-- | A SASL authentication mechanism offered by the client.
data SASLMechanism = SASLMechanism
    { saslName            :: !Text
      -- ^ name matched against the server's mechanism list
    , saslInitialResponse :: !BS.ByteString
      -- ^ response sent in @Connection_start_ok@
    , saslChallengeFunc   :: !(Maybe (BS.ByteString -> IO BS.ByteString))
      -- ^ computes the reply to a @Connection_secure@ challenge; the handshake
      -- is aborted if the server sends a challenge and this is 'Nothing'
    }
-- | Receive loop of a connection: read one frame, dispatch it, repeat.
--
-- Frames on channel 0 are handled here (close / close-ok); all others are
-- pushed onto the input queue of the addressed channel. The loop ends by
-- killing its own thread once the connection is closed or reading fails
-- with an 'CE.IOException'; the close reason is recorded in 'connClosed'.
connectionReceiver :: Connection -> IO ()
connectionReceiver conn = do
    CE.catch readAndDispatch onIOError
    connectionReceiver conn
  where
    readAndDispatch = do
        Frame chanID payload <- readFrame (connHandle conn)
        forwardToChannel chanID payload

    onIOError (e :: CE.IOException) = closeAndExit (show e)

    -- Record why the connection went away, then terminate this thread.
    closeAndExit reason = do
        modifyMVar_ (connClosed conn) $ \_ -> return (Just reason)
        killThread =<< myThreadId

    -- The broker confirmed a close that we initiated.
    forwardToChannel 0 (MethodPayload Connection_close_ok) =
        closeAndExit "closed by user"
    -- The broker initiated the close: acknowledge it before shutting down.
    forwardToChannel 0 (MethodPayload (Connection_close _ (ShortString errorMsg) _ _)) = do
        writeFrame (connHandle conn) $ Frame 0 $ MethodPayload Connection_close_ok
        closeAndExit (T.unpack errorMsg)
    forwardToChannel 0 payload =
        putStrLn $ "Got unexpected msg on channel zero: " ++ show payload
    forwardToChannel chanID payload =
        withMVar (connChannels conn) $ \cs ->
            case IM.lookup (fromIntegral chanID) cs of
                Nothing -> putStrLn $ "ERROR: channel not open " ++ show chanID
                Just c -> writeChan (inQueue $ fst c) payload
-- | Connect to the first reachable broker in 'coServers', perform the
-- connection handshake and start the background receiver thread.
--
-- Throws 'ConnectionClosedException' when no broker can be reached, when no
-- offered SASL mechanism is supported by the server, when the server sends a
-- challenge the selected mechanism cannot answer, or when the handshake fails
-- with an 'CE.IOException' (this includes failed pattern binds on unexpected
-- frames). In the 'CE.IOException' case the socket handle is now closed
-- before the exception is rethrown, so a failed handshake no longer leaks it.
openConnection'' :: ConnectionOpts -> IO Connection
openConnection'' connOpts = withSocketsDo $ do
    handle <- connect $ coServers connOpts
    maxFrameSize <- CE.handle (handshakeFailed handle) $ do
        -- Protocol header: the literal "AMQP" followed by four octets.
        BL.hPut handle $ BPut.runPut $ do
            BPut.putByteString $ BC.pack "AMQP"
            BPut.putWord8 1
            BPut.putWord8 1 --TCP/IP
            BPut.putWord8 0
            BPut.putWord8 9
        hFlush handle

        -- Connection.start / start-ok: pick a SASL mechanism both sides know.
        Frame 0 (MethodPayload (Connection_start _ _ _ (LongString serverMechanisms) _)) <- readFrame handle
        selectedSASL <- selectSASLMechanism handle serverMechanisms
        writeFrame handle $ start_ok selectedSASL

        -- Answer any number of challenges, then settle the frame size.
        Frame 0 (MethodPayload (Connection_tune _ frame_max _)) <- handleSecureUntilTune handle selectedSASL
        let negotiatedFrameSize = chooseMin frame_max $ coMaxFrameSize connOpts
        writeFrame handle (Frame 0 (MethodPayload
            (Connection_tune_ok 0 negotiatedFrameSize 0)
            ))

        writeFrame handle open
        Frame 0 (MethodPayload (Connection_open_ok _)) <- readFrame handle
        return negotiatedFrameSize

    cChannels <- newMVar IM.empty
    lastChanID <- newMVar 0
    cClosed <- newMVar Nothing
    writeLock <- newMVar ()
    ccl <- newEmptyMVar
    cClosedHandlers <- newMVar []
    let conn = Connection handle cChannels (fromIntegral maxFrameSize) cClosed ccl writeLock cClosedHandlers lastChanID

    -- Whatever ends the receiver thread, tear the connection down afterwards:
    -- close the handle, record a reason, kill the channel threads, release
    -- anyone blocked in 'closeConnection' and run the registered handlers.
    void $ forkIO $ CE.finally (connectionReceiver conn) $ do
        closeQuietly handle
        modifyMVar_ cClosed $ return . Just . fromMaybe "unknown reason"
        modifyMVar_ cChannels $ \x -> do
            mapM_ (killThread . snd) $ IM.elems x
            return IM.empty
        void $ tryPutMVar ccl ()
        withMVar cClosedHandlers sequence
    return conn
  where
    -- Close a handle, ignoring any exception raised while doing so.
    closeQuietly h = CE.catch (hClose h) (\(_ :: CE.SomeException) -> return ())

    -- An I/O failure during the handshake: release the socket, then report it.
    handshakeFailed h (_ :: CE.IOException) = do
        closeQuietly h
        CE.throwIO $ ConnectionClosedException "Handshake failed. Please check the RabbitMQ logs for more information"

    -- Try each broker in turn, logging failures, until one connects.
    connect ((host, port) : rest) = do
        result <- CE.try (connectTo host $ PortNumber port)
        either
            (\(ex :: CE.SomeException) -> do
                putStrLn $ "Error connecting to "++show (host, port)++": "++show ex
                connect rest)
            (return)
            result
    connect [] = CE.throwIO $ ConnectionClosedException $ "Could not connect to any of the provided brokers: " ++ show (coServers connOpts)

    -- First client mechanism whose name appears in the server's
    -- space-separated mechanism list; aborts the handshake if there is none.
    selectSASLMechanism handle serverMechanisms =
        let serverSaslList = T.split (== ' ') $ E.decodeUtf8 serverMechanisms
            clientMechanisms = coAuth connOpts
            clientSaslList = map saslName clientMechanisms
            maybeSasl = F.find (\(SASLMechanism name _ _) -> elem name serverSaslList) clientMechanisms
        in abortIfNothing maybeSasl handle
            ("None of the provided SASL mechanisms "++show clientSaslList++" is supported by the server "++show serverSaslList++".")

    start_ok sasl = (Frame 0 (MethodPayload (Connection_start_ok (FieldTable M.empty)
        (ShortString $ saslName sasl)
        (LongString $ saslInitialResponse sasl)
        (ShortString "en_US")) ))

    -- Reply to Connection.secure challenges until Connection.tune arrives.
    handleSecureUntilTune handle sasl = do
        tuneOrSecure <- readFrame handle
        case tuneOrSecure of
            Frame 0 (MethodPayload (Connection_secure (LongString challenge))) -> do
                processChallenge <- abortIfNothing (saslChallengeFunc sasl)
                    handle $ "The server provided a challenge, but the selected SASL mechanism "++show (saslName sasl)++" is not equipped with a challenge processing function."
                challengeResponse <- processChallenge challenge
                writeFrame handle (Frame 0 (MethodPayload (Connection_secure_ok (LongString challengeResponse))))
                handleSecureUntilTune handle sasl

            tune@(Frame 0 (MethodPayload (Connection_tune _ _ _))) -> return tune
            x -> error $ "handleSecureUntilTune fail. received message: "++show x

    open = (Frame 0 (MethodPayload (Connection_open
        (ShortString $ coVHost connOpts)
        (ShortString $ T.pack "")
        True)))

    -- Close the handle and fail the handshake with the given message.
    abortHandshake handle msg = do
        hClose handle
        CE.throwIO $ ConnectionClosedException msg

    abortIfNothing m handle msg = case m of
        Nothing -> abortHandshake handle msg
        Just a -> return a
-- | Ask the broker to close the connection and wait until it is closed.
--
-- A @Connection_close@ frame is written under the connection's write lock;
-- an 'CE.IOException' while writing is ignored. The call then blocks on
-- 'connClosedLock', which is filled once the receiver thread has finished.
closeConnection :: Connection -> IO ()
closeConnection c = do
    CE.catch sendClose ignoreIOError
    readMVar (connClosedLock c)
  where
    closeFrame = Frame 0 (MethodPayload (Connection_close 0 (ShortString "") 0 0))
    sendClose = withMVar (connWriteLock c) $ \_ -> writeFrame (connHandle c) closeFrame
    ignoreIOError (_ :: CE.IOException) = return ()
-- | Register an action to run when the connection is closed.
--
-- If the connection is already closed and @ifClosed@ is 'True', the handler
-- runs immediately (while 'connClosed' is held); in every other case it is
-- prepended to 'connClosedHandlers'.
addConnectionClosedHandler :: Connection -> Bool -> IO () -> IO ()
addConnectionClosedHandler conn ifClosed handler =
    withMVar (connClosed conn) $ \closedReason ->
        if isJust closedReason && ifClosed
            then handler
            else modifyMVar_ (connClosedHandlers conn) (return . (handler :))
-- | Read and decode exactly one frame from the handle.
--
-- Seven header bytes are read first; 'peekFrameSize' extracts the payload
-- length @len@ from them, after which @len + 1@ further bytes are read (the
-- complete frame is @len + 8@ bytes long).
--
-- Throws a user 'IOError' @"connection not open"@ if either read comes back
-- short. 'BL.hGet' only returns fewer bytes than requested at end of input,
-- so a truncated frame is now reported the same way as an empty read instead
-- of reaching 'peekFrameSize' or the parser with incomplete data. Decoding
-- failures of a fully read frame still raise an 'error'.
readFrame :: Handle -> IO Frame
readFrame handle = do
    dat <- BL.hGet handle headerSize
    when (BL.length dat < fromIntegral headerSize) notOpen
    let len = fromIntegral $ peekFrameSize dat
    dat' <- BL.hGet handle (len+1)
    when (BL.length dat' < fromIntegral (len+1)) notOpen
    let ret = runGetOrFail get (BL.append dat dat')
    case ret of
        Left (_, _, errMsg) -> error $ "readFrame fail: " ++ errMsg
        Right (_, consumedBytes, _) | consumedBytes /= fromIntegral (len+8) ->
            error $ "readFrame: parser should read " ++ show (len+8) ++ " bytes; but read " ++ show consumedBytes
        Right (_, _, frame) -> return frame
  where
    -- Number of bytes read before the payload length is known.
    headerSize = 7 :: Int
    notOpen = CE.throwIO $ userError "connection not open"
-- | Serialise a frame with its 'Binary' instance, write it to the handle and
-- flush. No locking is done here; callers serialise access themselves.
writeFrame :: Handle -> Frame -> IO ()
writeFrame handle f =
    BL.hPut handle (runPut (put f)) >> hFlush handle
-- | State of one channel multiplexed over a 'Connection'.
data Channel = Channel
    { connection           :: Connection
      -- ^ connection this channel belongs to
    , inQueue              :: Chan FramePayload
      -- ^ frames forwarded to this channel by the connection receiver
    , outstandingResponses :: MVar (Seq.Seq (MVar Assembly))
      -- ^ one slot per request awaiting a response, oldest first
    , channelID            :: Word16
      -- ^ channel number used when framing
    , lastConsumerTag      :: MVar Int
      -- ^ counter; NOTE(review): presumably used to generate consumer tags -
      -- it is not read in this module
    , chanActive           :: Lock
      -- ^ opened/closed by @Channel_flow@; content writes wait on it
    , chanClosed           :: MVar (Maybe String)
      -- ^ 'Nothing' while open; @Just reason@ once the channel is closed
    , consumers            :: MVar (M.Map Text ((Message, Envelope) -> IO ()))
      -- ^ delivery callbacks keyed by consumer tag
    , returnListeners      :: MVar [(Message, PublishError) -> IO ()]
      -- ^ callbacks invoked for returned messages
    }
-- | Build a 'Message' from @CHBasic@ content header properties and a body.
--
-- Only the properties that have a 'Message' field are kept. Partial: any
-- other kind of content header raises an 'error', and so does (lazily) a
-- delivery-mode octet unknown to 'intToDeliveryMode'.
msgFromContentHeaderProperties :: ContentHeaderProperties -> BL.ByteString -> Message
msgFromContentHeaderProperties (CHBasic content_type _ headers delivery_mode _ correlation_id reply_to _ message_id timestamp _ _ _ _) body =
    Message
        { msgBody          = body
        , msgDeliveryMode  = fmap intToDeliveryMode delivery_mode
        , msgTimestamp     = timestamp
        , msgID            = textOf message_id
        , msgContentType   = textOf content_type
        , msgReplyTo       = textOf reply_to
        , msgCorrelationID = textOf correlation_id
        , msgHeaders       = headers
        }
  where
    -- Unwrap an optional short string to its text.
    textOf (Just (ShortString s)) = Just s
    textOf _ = Nothing
msgFromContentHeaderProperties c _ = error ("Unknown content header properties: " ++ show c)
-- | Receive loop of a channel: read one 'Assembly' from the input queue and
-- either hand it to the oldest outstanding request or treat it as an
-- asynchronous event (delivery, return, flow control, close).
--
-- The loop ends when the thread kills itself after a @Channel_close@, or
-- when an exception escapes (for example a response with no pending request).
channelReceiver :: Channel -> IO ()
channelReceiver chan = forever $ do
    assembly <- readAssembly $ inQueue chan
    if isResponse assembly
        then deliverResponse assembly
        else handleAsync assembly
  where
    -- Pop the oldest waiter under the lock; fill it (or fail) after the lock
    -- has been released.
    deliverResponse assembly = join $
        modifyMVar (outstandingResponses chan) $ \pending ->
            case Seq.viewl pending of
                waiter Seq.:< rest ->
                    return (rest, putMVar waiter assembly)
                Seq.EmptyL ->
                    return (pending, CE.throwIO $ userError "got response, but have no corresponding request")

    -- Everything that is not one of the four asynchronous methods answers a
    -- request.
    isResponse :: Assembly -> Bool
    isResponse (ContentMethod (Basic_deliver _ _ _ _ _) _ _) = False
    isResponse (ContentMethod (Basic_return _ _ _ _) _ _) = False
    isResponse (SimpleMethod (Channel_flow _)) = False
    isResponse (SimpleMethod (Channel_close _ _ _ _)) = False
    isResponse _ = True

    -- Basic.deliver: run the callback registered for the consumer tag, if
    -- any. Exceptions from the callback are logged and swallowed.
    handleAsync (ContentMethod (Basic_deliver (ShortString consumerTag) deliveryTag redelivered (ShortString exchange)
                                                (ShortString routingKey))
                                properties body) =
        withMVar (consumers chan) $ \subscribers ->
            case M.lookup consumerTag subscribers of
                Nothing -> return ()
                Just subscriber -> do
                    let msg = msgFromContentHeaderProperties properties body
                        env = Envelope
                            { envDeliveryTag = deliveryTag
                            , envRedelivered = redelivered
                            , envExchangeName = exchange
                            , envRoutingKey = routingKey
                            , envChannel = chan
                            }
                    CE.catch (subscriber (msg, env)) $ \(e::CE.SomeException) ->
                        putStrLn $ "AMQP callback threw exception: " ++ show e
    -- Channel.close: mark the channel closed, then stop this thread.
    handleAsync (SimpleMethod (Channel_close _ (ShortString errorMsg) _ _)) =
        closeChannel' chan errorMsg >> myThreadId >>= killThread
    -- Channel.flow: open or close the lock that content writes wait on.
    handleAsync (SimpleMethod (Channel_flow active))
        | active = void $ openLock $ chanActive chan
        | otherwise = void $ closeLock $ chanActive chan
    --Basic.return
    handleAsync (ContentMethod basicReturn@(Basic_return _ _ _ _) props body) = do
        let returned = msgFromContentHeaderProperties props body
            pubError = basicReturnToPublishError basicReturn
        withMVar (returnListeners chan) $ \listeners ->
            forM_ listeners $ \listener ->
                CE.catch (listener (returned, pubError)) $ \(ex :: CE.SomeException) ->
                    putStrLn $ "return listener on channel ["++(show $ channelID chan)++"] handling error ["++show pubError++"] threw exception: "++show ex
    handleAsync m = error ("Unknown method: " ++ show m)

    -- Translate a Basic.return method into a 'PublishError'.
    basicReturnToPublishError (Basic_return code (ShortString errText) (ShortString exchange) (ShortString routingKey)) =
        PublishError replyError (Just exchange) routingKey
      where
        replyError = case code of
            312 -> Unroutable errText
            313 -> NoConsumers errText
            404 -> NotFound errText
            num -> error $ "unexpected return error code: " ++ (show num)
    basicReturnToPublishError x = error $ "basicReturnToPublishError fail: "++show x
-- | Register a callback for messages returned on this channel. The listener
-- is prepended to the channel's 'returnListeners'.
addReturnListener :: Channel -> ((Message, PublishError) -> IO ()) -> IO ()
addReturnListener chan listener =
    modifyMVar_ (returnListeners chan) (return . (listener :))
-- | Mark a channel as closed with the given reason.
--
-- The channel is removed from the connection's channel map. If it was not
-- already closed, its flow lock is killed, every request still waiting for a
-- response is woken with a value that raises @"channel closed"@ when forced,
-- and the reason is recorded. Closing an already closed channel keeps the
-- first reason.
closeChannel' :: Channel -> Text -> IO ()
closeChannel' c reason = do
    modifyMVar_ (connChannels $ connection c) $ \old -> return $ IM.delete (fromIntegral $ channelID c) old
    modifyMVar_ (chanClosed c) $ \x ->
        case x of
            Just _ -> return x
            Nothing -> do
                void $ killLock $ chanActive c
                killOutstandingResponses $ outstandingResponses c
                return $ Just $ T.unpack reason
  where
    killOutstandingResponses :: MVar (Seq.Seq (MVar a)) -> IO ()
    killOutstandingResponses outResps = do
        modifyMVar_ outResps $ \val -> do
            -- The waiter forces the value it receives, which turns this lazy
            -- 'error' into an exception in the requesting thread.
            F.mapM_ (\x -> tryPutMVar x $ error "channel closed") val
            -- Leave an empty queue behind rather than 'undefined', so a later
            -- reader of this MVar does not crash on a bottom value.
            return Seq.empty
-- | Open a new channel on the connection.
--
-- Allocates the next channel id, forks the channel's receiver thread (which
-- closes the channel with reason @"closed"@ when it ends), registers the
-- channel with the connection and performs the @Channel_open@ /
-- @Channel_open_ok@ exchange.
openChannel :: Connection -> IO Channel
openChannel c = do
    frames <- newChan
    pending <- newMVar Seq.empty
    consumerTagCounter <- newMVar 0
    flowLock <- newLock
    closedReason <- newMVar Nothing
    callbacks <- newMVar M.empty
    listeners <- newMVar []

    -- Bump the connection-wide counter and use the new value as our id.
    chanNum <- modifyMVar (lastChannelID c) $ \prev ->
        let next = prev + 1 in return (next, next)

    let chan = Channel
            { connection = c
            , inQueue = frames
            , outstandingResponses = pending
            , channelID = fromIntegral chanNum
            , lastConsumerTag = consumerTagCounter
            , chanActive = flowLock
            , chanClosed = closedReason
            , consumers = callbacks
            , returnListeners = listeners
            }

    receiver <- forkIO $ channelReceiver chan `CE.finally` closeChannel' chan "closed"
    modifyMVar_ (connChannels c) $ \chans -> return (IM.insert chanNum (chan, receiver) chans)

    (SimpleMethod (Channel_open_ok _)) <- request chan (SimpleMethod (Channel_open (ShortString "")))
    return chan
-- | Write a batch of frame payloads on the channel, in order, as one unit.
--
-- The connection's channel map is held for the whole call and the write lock
-- is taken around the batch. Throws a user 'IOError' @"channel not open"@ if
-- the channel is no longer registered, and @"connection not open"@ if
-- writing fails with an 'CE.IOException'.
writeFrames :: Channel -> [FramePayload] -> IO ()
writeFrames chan payloads =
    withMVar (connChannels conn) $ \chans ->
        if IM.member (fromIntegral $ channelID chan) chans
            then CE.catch sendAll connectionGone
            else CE.throwIO $ userError "channel not open"
  where
    conn = connection chan
    sendAll = withMVar (connWriteLock conn) $ \_ ->
        forM_ payloads $ \payload ->
            writeFrame (connHandle conn) (Frame (channelID chan) payload)
    connectionGone (_ :: CE.IOException) =
        CE.throwIO $ userError "connection not open"
-- | Write an 'Assembly' to the channel without translating exceptions.
--
-- A content method first waits on the channel's flow lock, then sends the
-- method frame, a content header frame and - for a non-empty body - the body
-- split into chunks of at most @connMaxFrameSize - 8@ bytes. A simple method
-- is sent as a single method frame.
writeAssembly' :: Channel -> Assembly -> IO ()
writeAssembly' chan (ContentMethod m properties msg) = do
    waitLock $ chanActive chan
    let !toWrite =
            [ MethodPayload m
            , ContentHeaderPayload
                (getClassIDOf properties) --classID
                0
                (fromIntegral $ BL.length msg) --bodySize
                properties
            ] ++
            (if BL.length msg > 0
                then map ContentBodyPayload
                        -- The subtraction operator was missing here; 8 is the
                        -- per-frame overhead (cf. @len + 8@ in 'readFrame').
                        (splitLen msg $ (fromIntegral $ connMaxFrameSize $ connection chan) - 8)
                else [])
    writeFrames chan toWrite
  where
    -- Cut the body into pieces of at most @len@ bytes.
    splitLen str len
        -- A non-positive chunk length would never make progress; send the
        -- body as a single frame instead of looping forever.
        -- NOTE(review): confirm how a frame size of 0 should be treated.
        | len <= 0 = [str]
        | BL.length str > len = BL.take len str : splitLen (BL.drop len str) len
        | otherwise = [str]
writeAssembly' chan (SimpleMethod m) = writeFrames chan [MethodPayload m]
-- | Write an 'Assembly' to the channel. Any 'AMQPException', 'CE.ErrorCall'
-- or 'CE.IOException' raised while writing is replaced by the exception
-- chosen by 'throwMostRelevantAMQPException'.
writeAssembly :: Channel -> Assembly -> IO ()
writeAssembly chan m = writeAssembly' chan m `CE.catches` handlers
  where
    rethrow = throwMostRelevantAMQPException chan
    handlers =
        [ CE.Handler $ \(_ :: AMQPException) -> rethrow
        , CE.Handler $ \(_ :: CE.ErrorCall) -> rethrow
        , CE.Handler $ \(_ :: CE.IOException) -> rethrow
        ]
-- | Send a request on the channel and block until its response arrives.
--
-- While holding 'chanClosed', a fresh response slot is appended to
-- 'outstandingResponses' and the request is written; a closed channel fails
-- with a user 'IOError' @"closed"@ instead. The response is forced when it
-- is taken, so the lazy error stored by a channel close surfaces here. Any
-- 'AMQPException', 'CE.ErrorCall' or 'CE.IOException' is replaced by the
-- exception chosen by 'throwMostRelevantAMQPException'.
request :: Channel -> Assembly -> IO Assembly
request chan m = do
    res <- newEmptyMVar
    CE.catches (enqueueAndSend res >> awaitReply res) handlers
  where
    enqueueAndSend res =
        withMVar (chanClosed chan) $ \cc ->
            case cc of
                Nothing -> do
                    modifyMVar_ (outstandingResponses chan) $ \val -> return $! val Seq.|> res
                    writeAssembly' chan m
                Just _ -> CE.throwIO $ userError "closed"

    -- Force the reply to weak head normal form before returning it.
    awaitReply res = takeMVar res >>= \r -> r `seq` return r

    rethrow = throwMostRelevantAMQPException chan
    handlers =
        [ CE.Handler $ \(_ :: AMQPException) -> rethrow
        , CE.Handler $ \(_ :: CE.ErrorCall) -> rethrow
        , CE.Handler $ \(_ :: CE.IOException) -> rethrow
        ]
-- | Throw the 'AMQPException' that best explains a failure on the channel:
-- 'ConnectionClosedException' with the recorded reason if the connection is
-- closed; otherwise 'ChannelClosedException' with the channel's reason if
-- the channel is closed; otherwise 'ConnectionClosedException' with
-- @"unknown reason"@. The channel state is only read when the connection is
-- still open.
throwMostRelevantAMQPException :: Channel -> IO a
throwMostRelevantAMQPException chan =
    readMVar (connClosed $ connection chan)
        >>= maybe blameChannel (CE.throwIO . ConnectionClosedException)
  where
    blameChannel =
        readMVar (chanClosed chan) >>= maybe
            (CE.throwIO $ ConnectionClosedException "unknown reason")
            (CE.throwIO . ChannelClosedException)
-- | Exceptions raised by this module when a channel or the whole connection
-- is no longer usable. The 'String' is the reason recorded at close time.
data AMQPException
    = ChannelClosedException String
    | ConnectionClosedException String
    deriving (Eq, Ord, Show, Typeable)

instance CE.Exception AMQPException