module Network.AMQP.Internal where
import Control.Concurrent
import Control.Monad
import Data.Binary
import Data.Binary.Get
import Data.Binary.Put as BPut
import Data.Int (Int64)
import Data.Maybe
import Data.Text (Text)
import Data.Typeable
import Network
import qualified Control.Exception as CE
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.Foldable as F
import qualified Data.IntMap as IM
import qualified Data.Sequence as Seq
import qualified Data.Text as T
import qualified Data.Text.Encoding as E
import qualified Network.Connection as Conn
import Network.AMQP.Protocol
import Network.AMQP.Types
import Network.AMQP.Helpers
import Network.AMQP.Generated
import Network.AMQP.ChannelAllocator
-- | Delivery mode carried in a message's content header.
--
-- Constructor order is significant for the derived 'Ord' instance.
data DeliveryMode
    = Persistent    -- ^ encoded on the wire as 2
    | NonPersistent -- ^ encoded on the wire as 1
    deriving (Eq, Ord, Read, Show)

-- | Encode a 'DeliveryMode' as its content-header octet.
deliveryModeToInt :: DeliveryMode -> Octet
deliveryModeToInt mode = case mode of
    NonPersistent -> 1
    Persistent    -> 2

-- | Decode a content-header octet into a 'DeliveryMode'.
--
-- Partial: any value other than 1 or 2 calls 'error'.
intToDeliveryMode :: Octet -> DeliveryMode
intToDeliveryMode n = case n of
    1 -> NonPersistent
    2 -> Persistent
    _ -> error ("Unknown delivery mode int: " ++ show n)
-- | A message as exposed by this library: the body plus the subset of
-- basic-class content-header properties that
-- 'msgFromContentHeaderProperties' extracts.
data Message = Message
    { msgBody          :: BL.ByteString
      -- ^ message payload (the concatenated content body frames)
    , msgDeliveryMode  :: Maybe DeliveryMode
    , msgTimestamp     :: Maybe Timestamp
    , msgID            :: Maybe Text
    , msgContentType   :: Maybe Text
    , msgReplyTo       :: Maybe Text
    , msgCorrelationID :: Maybe Text
    , msgHeaders       :: Maybe FieldTable
      -- ^ the headers table from the content header
    }
    deriving (Eq, Ord, Read, Show)
-- | Delivery metadata passed to a consumer callback alongside the 'Message'.
data Envelope = Envelope
    { envDeliveryTag  :: LongLongInt
      -- ^ delivery tag from Basic.Deliver
    , envRedelivered  :: Bool
      -- ^ redelivered flag from Basic.Deliver
    , envExchangeName :: Text
      -- ^ exchange the message was published to
    , envRoutingKey   :: Text
      -- ^ routing key the message was published with
    , envChannel      :: Channel
      -- ^ channel the delivery arrived on
    }
-- | Describes a message the server returned via Basic.Return.
data PublishError = PublishError
    { errReplyCode  :: ReturnReplyCode
      -- ^ why the message was returned
    , errExchange   :: Maybe Text
      -- ^ exchange the message was published to
    , errRoutingKey :: Text
      -- ^ routing key the message was published with
    }
    deriving (Eq, Read, Show)
-- | Reason a published message was returned, carrying the server's reply
-- text. Built from Basic.Return reply codes 312, 313 and 404 respectively.
data ReturnReplyCode
    = Unroutable Text
    | NoConsumers Text
    | NotFound Text
    deriving (Eq, Read, Show)
-- | A complete method reassembled from frames: either a bare method, or a
-- content-bearing method together with its content header properties and
-- the concatenated body.
data Assembly
    = SimpleMethod MethodPayload
    | ContentMethod MethodPayload ContentHeaderProperties BL.ByteString
    deriving Show
-- | Read the next complete method from a channel's frame queue.
--
-- The first frame must be a method frame. If 'hasContent' says it carries
-- content, the following header and body frames are collected with
-- 'collectContent'. Any other leading frame calls 'error'.
readAssembly :: Chan FramePayload -> IO Assembly
readAssembly chan = readChan chan >>= assemble
  where
    assemble frame@(MethodPayload method)
        | hasContent frame = do
            (props, body) <- collectContent chan
            return (ContentMethod method props body)
        | otherwise = return (SimpleMethod method)
    assemble other = error $ "didn't expect frame: " ++ show other
-- | Read a content header frame followed by as many content body frames as
-- needed to cover the body size announced in the header.
--
-- Returns the header properties and the concatenated body. If the frames do
-- not have the expected shape, the failed pattern bind in 'IO' throws an
-- 'IOError'.
collectContent :: Chan FramePayload -> IO (ContentHeaderProperties, BL.ByteString)
collectContent chan = do
    (ContentHeaderPayload _ _ bodySize props) <- readChan chan
    content <- collect $ fromIntegral bodySize
    return (props, BL.concat content)
  where
    -- Collect body frames until the remaining byte count drops to zero.
    collect remaining | remaining <= 0 = return []
    collect remaining = do
        (ContentBodyPayload payload) <- readChan chan
        rest <- collect (remaining - BL.length payload)
        return $ payload : rest
-- | State of an open AMQP connection.
--
-- Field order matters: 'openConnection''' builds this value positionally.
data Connection = Connection
    { connHandle         :: Conn.Connection
      -- ^ underlying transport
    , connChanAllocator  :: ChannelAllocator
      -- ^ allocator for channel ids
    , connChannels       :: MVar (IM.IntMap (Channel, ThreadId))
      -- ^ open channels by id, each with its receiver thread
    , connMaxFrameSize   :: Int
      -- ^ negotiated maximum frame size
    , connClosed         :: MVar (Maybe String)
      -- ^ Nothing while open, Just reason once closed
    , connClosedLock     :: MVar ()
      -- ^ filled by the receiver thread cleanup; closeConnection waits on it
    , connWriteLock      :: MVar ()
      -- ^ serialises writes to connHandle
    , connClosedHandlers :: MVar [IO ()]
      -- ^ actions run when the connection closes
    , connLastReceived   :: MVar Int64
      -- ^ getTimestamp value of the last received frame
    , connLastSent       :: MVar Int64
      -- ^ getTimestamp value of the last sent channel frames
    }
-- | Options for opening a connection.
data ConnectionOpts = ConnectionOpts
    { coServers        :: ![(String, PortNumber)]
      -- ^ brokers to try in order; the first that accepts a connection is used
    , coVHost          :: !Text
      -- ^ virtual host requested in Connection.Open
    , coAuth           :: ![SASLMechanism]
      -- ^ client SASL mechanisms in preference order; the first one the
      -- server also advertises is used
    , coMaxFrameSize   :: !(Maybe Word32)
      -- ^ client-side limit, combined with the server frame_max via chooseMin
    , coHeartbeatDelay :: !(Maybe Word16)
      -- ^ heartbeat value sent in Tune-Ok (presumably seconds, as in the AMQP
      -- Tune field); Just overrides the server proposal and 0 disables
      -- heartbeat watching
    , coMaxChannel     :: !(Maybe Word16)
      -- ^ NOTE(review): not read by the handshake code in this module
    , coTLSSettings    :: Maybe TLSSettings
      -- ^ Nothing for plain TCP, otherwise the TLS mode to use
    }
-- | TLS mode selectable through 'coTLSSettings'; translated by
-- 'connectionTLSSettings'.
data TLSSettings
    = TLSTrusted   -- ^ TLS with server certificate validation
    | TLSUntrusted -- ^ TLS with server certificate validation disabled
-- | Translate this module's 'TLSSettings' into the connection package's
-- settings. Always returns Just.
--
-- The two modes differ only in the first TLSSettingsSimple flag, which the
-- connection package documents as disabling certificate validation. Session
-- reuse stays enabled and server-name indication is not used in either mode.
connectionTLSSettings :: TLSSettings -> Maybe Conn.TLSSettings
connectionTLSSettings TLSTrusted   = Just (Conn.TLSSettingsSimple False False False)
connectionTLSSettings TLSUntrusted = Just (Conn.TLSSettingsSimple True False False)
-- | A client-side SASL authentication mechanism.
data SASLMechanism = SASLMechanism
    { saslName            :: !Text
      -- ^ mechanism name, matched against the server space-separated list
    , saslInitialResponse :: !BS.ByteString
      -- ^ sent as the response field of Connection.Start-Ok
    , saslChallengeFunc   :: !(Maybe (BS.ByteString -> IO BS.ByteString))
      -- ^ answers Connection.Secure challenges; the handshake aborts if the
      -- server sends a challenge and this is Nothing
    }
-- | Body of the connection receiver thread. It loops forever, reading frames
-- and routing them.
--
-- Every frame refreshes 'connLastReceived'. Frames on channel 0 are
-- connection-level:
--
-- * Connection.Close-Ok ends the connection with reason "closed by user".
-- * A server Connection.Close is answered with Close-Ok, then ends the
--   connection with the server message.
-- * Heartbeats are ignored, and anything else is printed.
--
-- Frames on other channels are pushed onto that channel 'inQueue'. If the
-- channel is unknown, an error line is printed instead.
--
-- An 'CE.IOException' in one iteration ends the connection with @show e@ as
-- the reason. Ending the connection kills this same thread via
-- 'killConnection' on 'myThreadId'.
connectionReceiver :: Connection -> IO ()
connectionReceiver conn = forever $
    receiveOne `CE.catch` \(e :: CE.IOException) ->
        myThreadId >>= killConnection conn (show e)
  where
    receiveOne = do
        Frame chanID payload <- readFrame (connHandle conn)
        updateLastReceived conn
        route chanID payload

    route 0 (MethodPayload Connection_close_ok) =
        myThreadId >>= killConnection conn "closed by user"
    route 0 (MethodPayload (Connection_close _ (ShortString errorMsg) _ _)) = do
        writeFrame (connHandle conn) $ Frame 0 $ MethodPayload Connection_close_ok
        myThreadId >>= killConnection conn (T.unpack errorMsg)
    route 0 HeartbeatPayload = return ()
    route 0 payload = putStrLn $ "Got unexpected msg on channel zero: " ++ show payload
    route chanID payload = withMVar (connChannels conn) $ \chans ->
        case IM.lookup (fromIntegral chanID) chans of
            Just (ch, _) -> writeChan (inQueue ch) payload
            Nothing -> putStrLn $ "ERROR: channel not open " ++ show chanID
-- | Connect to the first reachable broker and run the AMQP 0-9-1 handshake,
-- then start the threads that service the connection.
--
-- Steps:
--
-- 1. Try each (host, port) from 'coServers' in order, printing the error for
--    each failure. Throw 'ConnectionClosedException' if none succeed.
-- 2. Send the protocol header and pick the first client SASL mechanism the
--    server also offers. Answer any Connection.Secure challenges, then
--    negotiate frame size and heartbeat through Connection.Tune and Tune-Ok.
--    Finally open 'coVHost'.
-- 3. Fork the receiver thread. When it exits, the handle is closed, the close
--    reason is recorded, channel threads are killed, 'connClosedLock' is
--    filled and the closed-handlers run.
-- 4. If the negotiated heartbeat is non-zero, start 'watchHeartbeats'. That
--    watcher is killed when the connection closes.
--
-- Any exception during step 2 closes the transport handle before it
-- propagates. An 'CE.IOException' there is rethrown as a
-- 'ConnectionClosedException'.
openConnection'' :: ConnectionOpts -> IO Connection
openConnection'' connOpts = withSocketsDo $ do
    handle <- connect $ coServers connOpts
    (maxFrameSize, heartbeatTimeout) <-
        CE.handle (\(_ :: CE.IOException) -> CE.throwIO $ ConnectionClosedException "Handshake failed. Please check the RabbitMQ logs for more information") $
        -- Close the handle on every failure path: failed pattern binds,
        -- error calls and explicit aborts. Otherwise a failed handshake
        -- leaks the socket or TLS context.
        flip CE.onException (closeQuietly handle) $ do
            Conn.connectionPut handle $ BS.append (BC.pack "AMQP")
                (BS.pack
                    [ 1
                    , 1 --TCP/IP
                    , 0
                    , 9
                    ])
            Frame 0 (MethodPayload (Connection_start _ _ _ (LongString serverMechanisms) _)) <- readFrame handle
            selectedSASL <- selectSASLMechanism handle serverMechanisms
            writeFrame handle $ start_ok selectedSASL
            Frame 0 (MethodPayload (Connection_tune _ frame_max sendHeartbeat)) <- handleSecureUntilTune handle selectedSASL
            let maxFrameSize = chooseMin frame_max $ coMaxFrameSize connOpts
                finalHeartbeatSec = fromMaybe sendHeartbeat (coHeartbeatDelay connOpts)
                -- a heartbeat of 0 means no heartbeat watching
                heartbeatTimeout = mfilter (/= 0) (Just finalHeartbeatSec)
            writeFrame handle (Frame 0 (MethodPayload
                (Connection_tune_ok 0 maxFrameSize finalHeartbeatSec)))
            writeFrame handle open
            Frame 0 (MethodPayload (Connection_open_ok _)) <- readFrame handle
            return (maxFrameSize, heartbeatTimeout)

    cChannels <- newMVar IM.empty
    cClosed <- newMVar Nothing
    cChanAllocator <- newChannelAllocator
    -- NOTE(review): presumably reserves channel 0 for connection-level frames
    _ <- allocateChannel cChanAllocator
    writeLock <- newMVar ()
    ccl <- newEmptyMVar
    cClosedHandlers <- newMVar []
    cLastReceived <- getTimestamp >>= newMVar
    cLastSent <- getTimestamp >>= newMVar
    let conn = Connection handle cChanAllocator cChannels (fromIntegral maxFrameSize) cClosed ccl writeLock cClosedHandlers cLastReceived cLastSent

    connThread <- forkIO $ CE.finally (connectionReceiver conn) $ do
        closeQuietly handle
        modifyMVar_ cClosed $ return . Just . fromMaybe "unknown reason"
        modifyMVar_ cChannels $ \x -> do
            mapM_ (killThread . snd) $ IM.elems x
            return IM.empty
        void $ tryPutMVar ccl ()
        withMVar cClosedHandlers sequence

    case heartbeatTimeout of
        Nothing -> return ()
        Just timeout -> do
            heartbeatThread <- watchHeartbeats conn (fromIntegral timeout) connThread
            addConnectionClosedHandler conn True (killThread heartbeatThread)
    return conn
  where
    -- Try each broker in turn; print and skip any that fail.
    connect ((host, port) : rest) = do
        ctx <- Conn.initConnectionContext
        result <- CE.try (Conn.connectTo ctx $ Conn.ConnectionParams
            { Conn.connectionHostname = host
            , Conn.connectionPort = port
            , Conn.connectionUseSecure = tlsSettings
            , Conn.connectionUseSocks = Nothing
            })
        either
            (\(ex :: CE.SomeException) -> do
                putStrLn $ "Error connecting to " ++ show (host, port) ++ ": " ++ show ex
                connect rest)
            return
            result
    connect [] = CE.throwIO $ ConnectionClosedException $ "Could not connect to any of the provided brokers: " ++ show (coServers connOpts)

    tlsSettings = coTLSSettings connOpts >>= connectionTLSSettings

    -- Pick the first client mechanism whose name the server advertises.
    selectSASLMechanism handle serverMechanisms =
        let serverSaslList = T.split (== ' ') $ E.decodeUtf8 serverMechanisms
            clientMechanisms = coAuth connOpts
            clientSaslList = map saslName clientMechanisms
            maybeSasl = F.find (\(SASLMechanism name _ _) -> elem name serverSaslList) clientMechanisms
        in abortIfNothing maybeSasl handle
            ("None of the provided SASL mechanisms " ++ show clientSaslList ++ " is supported by the server " ++ show serverSaslList ++ ".")

    start_ok sasl = (Frame 0 (MethodPayload (Connection_start_ok (FieldTable M.empty)
        (ShortString $ saslName sasl)
        (LongString $ saslInitialResponse sasl)
        (ShortString "en_US"))))

    -- Answer Connection.Secure challenges until Connection.Tune arrives.
    handleSecureUntilTune handle sasl = do
        tuneOrSecure <- readFrame handle
        case tuneOrSecure of
            Frame 0 (MethodPayload (Connection_secure (LongString challenge))) -> do
                processChallenge <- abortIfNothing (saslChallengeFunc sasl) handle $
                    "The server provided a challenge, but the selected SASL mechanism " ++ show (saslName sasl) ++ " is not equipped with a challenge processing function."
                challengeResponse <- processChallenge challenge
                writeFrame handle (Frame 0 (MethodPayload (Connection_secure_ok (LongString challengeResponse))))
                handleSecureUntilTune handle sasl
            tune@(Frame 0 (MethodPayload (Connection_tune _ _ _))) -> return tune
            x -> error $ "handleSecureUntilTune fail. received message: " ++ show x

    open = (Frame 0 (MethodPayload (Connection_open
        (ShortString $ coVHost connOpts)
        (ShortString $ T.pack "")
        True)))

    abortHandshake handle msg = do
        Conn.connectionClose handle
        CE.throwIO $ ConnectionClosedException msg

    abortIfNothing m handle msg = case m of
        Nothing -> abortHandshake handle msg
        Just a -> return a

    -- Close the transport, ignoring errors (for example if it is already closed).
    closeQuietly h = CE.catch (Conn.connectionClose h) (\(_ :: CE.SomeException) -> return ())
-- | Start a periodic task via 'scheduleAtFixedRate' that runs every @rate@
-- units, where @rate = timeout * 250000@. That is a quarter of the heartbeat
-- interval if the unit is microseconds; NOTE(review): confirm the units of
-- scheduleAtFixedRate and getTimestamp.
--
-- Each tick does two things:
--
-- * It sends a heartbeat frame when 'connLastSent' is at least two ticks old.
-- * It kills the connection thread when 'connLastReceived' is at least eight
--   ticks (two heartbeat intervals) old.
--
-- Returns the thread id produced by 'scheduleAtFixedRate'.
watchHeartbeats :: Connection -> Int -> ThreadId -> IO ThreadId
watchHeartbeats conn timeout connThread = scheduleAtFixedRate rate $ do
    checkSendTimeout
    checkReceiveTimeout
  where
    rate = timeout * 1000 * 250
    receiveTimeout = (fromIntegral rate) * 4 * 2
    sendTimeout = (fromIntegral rate) * 2

    checkReceiveTimeout = check (connLastReceived conn) receiveTimeout
        (killConnection conn "killed connection after missing 2 consecutive heartbeats" connThread)

    -- Take connWriteLock like the other writers (writeFrames,
    -- closeConnection). Without it, a heartbeat frame could be written
    -- concurrently with another thread's frames on the same transport.
    -- Lock order here is connLastSent then connWriteLock. writeFrames
    -- releases connWriteLock before touching connLastSent, so no cycle
    -- is formed.
    checkSendTimeout = check (connLastSent conn) sendTimeout
        (withMVar (connWriteLock conn) $ \_ ->
            writeFrame (connHandle conn) (Frame 0 HeartbeatPayload))

    -- Run action if at least timeoutMicros has elapsed since the stored time.
    check var timeoutMicros action = withMVar var $ \lastFrameTime -> do
        time <- getTimestamp
        when (time >= lastFrameTime + timeoutMicros) action
-- | Record the current 'getTimestamp' value as the time of the last send.
updateLastSent :: Connection -> IO ()
updateLastSent conn = modifyMVar_ (connLastSent conn) $ \_ -> getTimestamp

-- | Record the current 'getTimestamp' value as the time of the last receive.
updateLastReceived :: Connection -> IO ()
updateLastReceived conn = modifyMVar_ (connLastReceived conn) $ \_ -> getTimestamp
-- | Record @msg@ as the close reason, overwriting any previous one, then kill
-- the given thread (the connection receiver thread at every call site here).
killConnection :: Connection -> String -> ThreadId -> IO ()
killConnection conn msg connThread = do
    modifyMVar_ (connClosed conn) $ \_ -> return (Just msg)
    killThread connThread
-- | Ask the server to close the connection, then block until the receiver
-- thread cleanup has run (it fills 'connClosedLock').
--
-- An 'CE.IOException' while sending Connection.Close is ignored, since the
-- transport may already be gone.
closeConnection :: Connection -> IO ()
closeConnection c = do
    let closeFrame = Frame 0 (MethodPayload (Connection_close 0 (ShortString "") 0 0))
        sendClose = withMVar (connWriteLock c) $ \_ -> writeFrame (connHandle c) closeFrame
    sendClose `CE.catch` \(_ :: CE.IOException) -> return ()
    readMVar (connClosedLock c)
-- | Register an action to run when the connection closes.
--
-- If the connection is already closed and @ifClosed@ is True, the handler runs
-- immediately instead. Otherwise it is added to 'connClosedHandlers'. If the
-- connection has already closed, that can mean it never runs. Both paths hold
-- 'connClosed' while they execute.
addConnectionClosedHandler :: Connection -> Bool -> IO () -> IO ()
addConnectionClosedHandler conn ifClosed handler =
    withMVar (connClosed conn) $ \closedReason ->
        if ifClosed && isJust closedReason
            then handler
            else modifyMVar_ (connClosedHandlers conn) (return . (handler :))
-- | Read one complete AMQP frame from the connection.
--
-- First reads the 7-byte frame header, and uses 'peekFrameSize' on it to
-- learn the payload length @len@. Then reads the payload plus the trailing
-- frame-end octet (@len + 1@ bytes). Finally decodes all @len + 8@ bytes
-- with the Binary instance for 'Frame'.
--
-- Throws a 'userError' IOException ("connection not open") if either read
-- yields no bytes. Calls 'error' if decoding fails or does not consume
-- exactly @len + 8@ bytes.
readFrame :: Conn.Connection -> IO Frame
readFrame handle = do
    -- frame header: enough bytes for peekFrameSize
    strictDat <- connectionGetExact handle 7
    let dat = toLazy strictDat
    when (BL.null dat) $ CE.throwIO $ userError "connection not open"
    let len = fromIntegral $ peekFrameSize dat
    -- payload plus the single frame-end octet
    strictDat' <- connectionGetExact handle (len+1)
    let dat' = toLazy strictDat'
    when (BL.null dat') $ CE.throwIO $ userError "connection not open"
    -- The two branches below are equivalent; newer binary versions expose
    -- runGetOrFail, older ones only runGetState.
#if MIN_VERSION_binary(0, 7, 0)
    let ret = runGetOrFail get (BL.append dat dat')
    case ret of
        Left (_, _, errMsg) -> error $ "readFrame fail: " ++ errMsg
        Right (_, consumedBytes, _) | consumedBytes /= fromIntegral (len+8) ->
            error $ "readFrame: parser should read " ++ show (len+8) ++ " bytes; but read " ++ show consumedBytes
        Right (_, _, frame) -> return frame
#else
    let (frame, _, consumedBytes) = runGetState get (BL.append dat dat') 0
    if consumedBytes /= fromIntegral (len+8)
        then error $ "readFrameSock: parser should read "++show (len+8)++" bytes; but read "++show consumedBytes
        else return ()
    return frame
#endif
-- | Read exactly @x@ bytes from the connection, blocking until they arrive.
-- Returns an empty string when @x <= 0@.
--
-- 'Conn.connectionGet' may return fewer bytes than requested, so this keeps
-- asking for the remainder. Chunks are collected in a list and concatenated
-- once at the end, rather than re-copying the accumulated prefix on every
-- read.
--
-- Throws a 'userError' IOException ("connection not open") when a read returns
-- no bytes, which the connection package uses to signal end of input.
connectionGetExact :: Conn.Connection -> Int -> IO BS.ByteString
connectionGetExact conn x = loop [] 0
  where
    loop chunks got
        | got >= x = return $! BS.concat (reverse chunks)
        | otherwise = do
            next <- Conn.connectionGet conn (x - got)
            when (BS.null next) $ CE.throwIO $ userError "connection not open"
            loop (next : chunks) (got + BS.length next)
-- | Serialise a frame with its Binary instance and send it with a single
-- 'Conn.connectionPut'. No locking is done here.
writeFrame :: Conn.Connection -> Frame -> IO ()
writeFrame handle = Conn.connectionPut handle . toStrict . runPut . put
-- | State of an open AMQP channel.
--
-- Field order matters: 'openChannel' builds this value positionally.
data Channel = Channel
    { connection           :: Connection
      -- ^ owning connection
    , inQueue              :: Chan FramePayload
      -- ^ frames routed to this channel by the connection receiver
    , outstandingResponses :: MVar (Seq.Seq (MVar Assembly))
      -- ^ reply slots for pending requests, oldest first
    , channelID            :: Word16
    , lastConsumerTag      :: MVar Int
      -- ^ NOTE(review): counter, presumably used to generate consumer tags
    , chanActive           :: Lock
      -- ^ flow-control lock toggled by Channel.Flow; awaited before content writes
    , chanClosed           :: MVar (Maybe String)
      -- ^ Nothing while open, Just reason once closed
    , consumers            :: MVar (M.Map Text ((Message, Envelope) -> IO ()))
      -- ^ delivery callbacks by consumer tag
    , returnListeners      :: MVar ([(Message, PublishError) -> IO ()])
      -- ^ callbacks invoked for Basic.Return
    }
-- | Build a 'Message' from basic-class content header properties and a body.
--
-- Only CHBasic headers are supported; any other properties value calls
-- 'error'. Short-string properties are unwrapped to Text.
msgFromContentHeaderProperties :: ContentHeaderProperties -> BL.ByteString -> Message
msgFromContentHeaderProperties (CHBasic content_type _ headers delivery_mode _ correlation_id reply_to _ message_id timestamp _ _ _ _) body =
    Message
        { msgBody = body
        , msgDeliveryMode = fmap intToDeliveryMode delivery_mode
        , msgTimestamp = timestamp
        , msgID = unShort message_id
        , msgContentType = unShort content_type
        , msgReplyTo = unShort reply_to
        , msgCorrelationID = unShort correlation_id
        , msgHeaders = headers
        }
  where
    unShort (Just (ShortString s)) = Just s
    unShort _ = Nothing
msgFromContentHeaderProperties c _ = error ("Unknown content header properties: " ++ show c)
-- | Body of a channel receiver thread. It loops forever, assembling frames
-- from 'inQueue' into methods and dispatching them.
--
-- Replies to synchronous requests (everything 'isResponse' accepts) are handed
-- to the oldest slot in 'outstandingResponses', so replies are matched to
-- requests in FIFO order. A reply with no waiting request throws a
-- 'userError'.
--
-- Asynchronous methods are handled by handleAsync: deliveries, returns, flow
-- control and server-initiated close.
channelReceiver :: Channel -> IO ()
channelReceiver chan = do
    p <- readAssembly $ inQueue chan
    if isResponse p
        then do
            action <- modifyMVar (outstandingResponses chan) $ \val ->
                case Seq.viewl val of
                    x Seq.:< rest -> return (rest, putMVar x p)
                    Seq.EmptyL -> return (val, CE.throwIO $ userError "got response, but have no corresponding request")
            action
        else handleAsync p
    channelReceiver chan
  where
    -- Methods the server sends unsolicited; everything else is a reply.
    isResponse :: Assembly -> Bool
    isResponse (ContentMethod (Basic_deliver _ _ _ _ _) _ _) = False
    isResponse (ContentMethod (Basic_return _ _ _ _) _ _) = False
    isResponse (SimpleMethod (Channel_flow _)) = False
    isResponse (SimpleMethod (Channel_close _ _ _ _)) = False
    isResponse _ = True

    -- Basic.Deliver: run the consumer registered for the tag, if any.
    -- Callback exceptions are printed, not propagated.
    -- NOTE(review): the callback runs while the consumers MVar is held, so a
    -- callback that modifies consumers for this channel would block.
    handleAsync (ContentMethod (Basic_deliver (ShortString consumerTag) deliveryTag redelivered (ShortString exchange) (ShortString routingKey)) properties body) =
        withMVar (consumers chan) $ \s ->
            case M.lookup consumerTag s of
                Just subscriber -> do
                    let msg = msgFromContentHeaderProperties properties body
                        env = Envelope
                            { envDeliveryTag = deliveryTag
                            , envRedelivered = redelivered
                            , envExchangeName = exchange
                            , envRoutingKey = routingKey
                            , envChannel = chan
                            }
                    CE.catch (subscriber (msg, env))
                        (\(e :: CE.SomeException) -> putStrLn $ "AMQP callback threw exception: " ++ show e)
                Nothing -> return ()
    -- Server-initiated Channel.Close. AMQP requires a Channel.Close-Ok reply;
    -- without it the server keeps the channel in the closing state. The reply
    -- is best effort, because the connection may already be gone. Then mark
    -- the channel closed and end this thread.
    handleAsync (SimpleMethod (Channel_close _ (ShortString errorMsg) _ _)) = do
        CE.catch (writeAssembly' chan (SimpleMethod Channel_close_ok))
            (\(_ :: CE.IOException) -> return ())
        closeChannel' chan errorMsg
        killThread =<< myThreadId
    -- Channel.Flow only toggles the local lock that gates content writes;
    -- no Channel.Flow-Ok is sent.
    handleAsync (SimpleMethod (Channel_flow active)) = do
        if active
            then openLock $ chanActive chan
            else closeLock $ chanActive chan
        return ()
    -- Basic.Return: notify every return listener. Listener exceptions are
    -- printed, not propagated.
    handleAsync (ContentMethod basicReturn@(Basic_return _ _ _ _) props body) = do
        let msg = msgFromContentHeaderProperties props body
            pubError = basicReturnToPublishError basicReturn
        withMVar (returnListeners chan) $ \listeners ->
            forM_ listeners $ \l -> CE.catch (l (msg, pubError)) $ \(ex :: CE.SomeException) ->
                putStrLn $ "return listener on channel [" ++ (show $ channelID chan) ++ "] handling error [" ++ show pubError ++ "] threw exception: " ++ show ex
    handleAsync m = error ("Unknown method: " ++ show m)

    -- Map a Basic.Return reply code to a PublishError; unknown codes call error.
    basicReturnToPublishError (Basic_return code (ShortString errText) (ShortString exchange) (ShortString routingKey)) =
        let replyError = case code of
                312 -> Unroutable errText
                313 -> NoConsumers errText
                404 -> NotFound errText
                num -> error $ "unexpected return error code: " ++ (show num)
            pubError = PublishError replyError (Just exchange) routingKey
        in pubError
    basicReturnToPublishError x = error $ "basicReturnToPublishError fail: " ++ show x
-- | Register a callback that is invoked for each Basic.Return on this
-- channel. The newest listener is placed first in the list.
addReturnListener :: Channel -> ((Message, PublishError) -> IO ()) -> IO ()
addReturnListener chan listener =
    modifyMVar_ (returnListeners chan) (return . (listener :))
-- | Mark a channel closed locally, using @reason@ as the close reason. This
-- sends nothing to the server. A second call is a no-op.
--
-- On the first call it does the following:
--
-- * removes the channel from 'connChannels' and frees its id (printing an
--   error if the allocator reports it already free);
-- * kills the flow-control lock;
-- * fails every pending request by filling its reply slot with an 'error'
--   value that throws "channel closed" when forced.
closeChannel' :: Channel -> Text -> IO ()
closeChannel' c reason =
    modifyMVar_ (chanClosed c) $ \closedReason ->
        case closedReason of
            Just _ -> return closedReason
            Nothing -> do
                modifyMVar_ (connChannels $ connection c) $ \old -> do
                    ret <- freeChannel (connChanAllocator $ connection c) $ fromIntegral $ channelID c
                    unless ret $ putStrLn "closeChannel error: channel already freed"
                    return $ IM.delete (fromIntegral $ channelID c) old
                void $ killLock $ chanActive c
                killOutstandingResponses $ outstandingResponses c
                return $ Just $ T.unpack reason
  where
    -- Wake every waiting requester with an error, and leave an empty queue
    -- behind. Previously this stored undefined, which would crash any later
    -- access.
    killOutstandingResponses :: MVar (Seq.Seq (MVar a)) -> IO ()
    killOutstandingResponses outResps =
        modifyMVar_ outResps $ \val -> do
            F.mapM_ (\x -> tryPutMVar x $ error "channel closed") val
            return Seq.empty
-- | Open a new channel on the connection.
--
-- Allocates a channel id and registers the channel in 'connChannels' with a
-- freshly forked receiver thread. The thread runs 'closeChannel'' when it
-- exits. Then sends Channel.Open and waits for Channel.Open-Ok; a different
-- reply fails the pattern bind and throws.
openChannel :: Connection -> IO Channel
openChannel c = do
    newInQueue <- newChan
    outRes <- newMVar Seq.empty
    lastConsTag <- newMVar 0
    ca <- newLock
    closed <- newMVar Nothing
    conss <- newMVar M.empty
    listeners <- newMVar []
    newChannel <- modifyMVar (connChannels c) $ \mp -> do
        newChannelID <- allocateChannel (connChanAllocator c)
        -- Check before forking. Otherwise a stray receiver thread for a
        -- duplicate id would run closeChannel' when it dies, which would
        -- remove and free the channel that really owns this id.
        when (IM.member newChannelID mp) $ CE.throwIO $ userError "openChannel fail: channel already open"
        let chan = Channel c newInQueue outRes (fromIntegral newChannelID) lastConsTag ca closed conss listeners
        thrID <- forkIO $ CE.finally (channelReceiver chan)
                                     (closeChannel' chan "closed")
        return (IM.insert newChannelID (chan, thrID) mp, chan)
    SimpleMethod (Channel_open_ok _) <- request newChannel (SimpleMethod (Channel_open (ShortString "")))
    return newChannel
-- | Close a channel at the user's request. It sends Channel.Close and expects
-- Channel.Close-Ok (any other reply fails the pattern bind), then marks the
-- channel closed locally.
closeChannel :: Channel -> IO ()
closeChannel c = do
    let closeRequest = SimpleMethod $ Channel_close 0 (ShortString "") 0 0
    SimpleMethod Channel_close_ok <- request c closeRequest
    closeChannel' c "user called closeChannel"
-- | Write a sequence of frames on this channel as one unit.
--
-- 'connChannels' is held for the whole write, and 'connWriteLock' around the
-- frame writes, so other writers cannot interleave their frames. Afterwards
-- 'connLastSent' is refreshed. Throws a 'userError' "channel not open" if the
-- channel is not registered. Any 'CE.IOException' during the write is
-- rethrown as a 'userError' "connection not open".
writeFrames :: Channel -> [FramePayload] -> IO ()
writeFrames chan payloads =
    withMVar (connChannels conn) $ \openChans -> do
        unless (IM.member (fromIntegral $ channelID chan) openChans) $
            CE.throwIO $ userError "channel not open"
        sendAll `CE.catch` \(_ :: CE.IOException) ->
            CE.throwIO $ userError "connection not open"
  where
    conn = connection chan
    sendAll = do
        withMVar (connWriteLock conn) $ \_ ->
            mapM_ (writeFrame (connHandle conn) . Frame (channelID chan)) payloads
        updateLastSent conn
-- | Send a method on the channel.
--
-- A content-bearing method is sent as three parts: the method frame, a
-- content header frame and body frames. No body frames are sent for an empty
-- body. Each body frame holds at most connMaxFrameSize - 8 bytes. Content
-- sends first call 'waitLock' on 'chanActive', the lock that Channel.Flow
-- toggles; presumably this blocks while flow is paused.
writeAssembly' :: Channel -> Assembly -> IO ()
writeAssembly' chan (ContentMethod m properties msg) = do
    waitLock $ chanActive chan
    let !toWrite =
            MethodPayload m
                : ContentHeaderPayload
                    (getClassIDOf properties) --classID
                    0
                    (fromIntegral bodyLen) --bodySize
                    properties
                : bodyFrames
    writeFrames chan toWrite
  where
    bodyLen = BL.length msg
    -- Every frame carries 8 bytes of framing (7-byte header plus the
    -- frame-end octet), so a body frame payload is at most maxFrameSize - 8.
    maxPayload = fromIntegral (connMaxFrameSize $ connection chan) - 8
    bodyFrames
        | bodyLen > 0 = map ContentBodyPayload (splitLen msg maxPayload)
        | otherwise = []
    -- Split into chunks of at most len bytes, without re-measuring the
    -- remaining string on every step.
    splitLen str len
        | BL.null rest = [chunk]
        | otherwise = chunk : splitLen rest len
      where
        (chunk, rest) = BL.splitAt len str
writeAssembly' chan (SimpleMethod m) = writeFrames chan [MethodPayload m]
-- | Like 'writeAssembly'', but a failure of type 'AMQPException',
-- 'CE.ErrorCall' or 'CE.IOException' is replaced by the exception chosen by
-- 'throwMostRelevantAMQPException'.
writeAssembly :: Channel -> Assembly -> IO ()
writeAssembly chan m = CE.catches (writeAssembly' chan m)
    [ CE.Handler $ \(_ :: AMQPException) -> rethrow
    , CE.Handler $ \(_ :: CE.ErrorCall) -> rethrow
    , CE.Handler $ \(_ :: CE.IOException) -> rethrow
    ]
  where
    rethrow = throwMostRelevantAMQPException chan
-- | Send a synchronous method and wait for its reply.
--
-- While holding 'chanClosed', a fresh reply slot is appended to
-- 'outstandingResponses' and the method is sent; if the channel is already
-- closed a 'userError' "closed" is thrown instead. The reply is forced to WHNF
-- when taken, so an error value placed in the slot (as 'closeChannel'' does)
-- is raised here. Failures of type 'AMQPException', 'CE.ErrorCall' or
-- 'CE.IOException' are replaced by 'throwMostRelevantAMQPException'.
request :: Channel -> Assembly -> IO Assembly
request chan m = do
    res <- newEmptyMVar
    CE.catches (enqueueAndSend res >> (CE.evaluate =<< takeMVar res))
        [ CE.Handler (\(_ :: AMQPException) -> throwMostRelevantAMQPException chan)
        , CE.Handler (\(_ :: CE.ErrorCall) -> throwMostRelevantAMQPException chan)
        , CE.Handler (\(_ :: CE.IOException) -> throwMostRelevantAMQPException chan)
        ]
  where
    enqueueAndSend res = withMVar (chanClosed chan) $ \closedReason ->
        case closedReason of
            Just _ -> CE.throwIO $ userError "closed"
            Nothing -> do
                modifyMVar_ (outstandingResponses chan) $ \pending -> return $! pending Seq.|> res
                writeAssembly' chan m
-- | Throw the most informative 'AMQPException' for a failed channel operation.
-- This never returns normally. It picks the first of these that applies:
--
-- * the connection close reason, if the connection is closed;
-- * otherwise the channel close reason, if the channel is closed;
-- * otherwise a 'ConnectionClosedException' with "unknown reason".
--
-- The channel state is only read when the connection is still open.
throwMostRelevantAMQPException :: Channel -> IO a
throwMostRelevantAMQPException chan =
    readMVar (connClosed (connection chan))
        >>= maybe fromChannel (CE.throwIO . ConnectionClosedException)
  where
    fromChannel =
        readMVar (chanClosed chan)
            >>= CE.throwIO . maybe (ConnectionClosedException "unknown reason") ChannelClosedException
-- | Exceptions thrown by channel and connection operations, each carrying the
-- close reason.
data AMQPException
    = ChannelClosedException String    -- ^ the channel was closed
    | ConnectionClosedException String -- ^ the connection was closed or could not be opened
    deriving (Typeable, Show, Ord, Eq)

instance CE.Exception AMQPException