module Network.Stomp (
Command (..),
ClientCommand (..),
ServerCommand (..),
Frame (..),
Connection,
StompUri,
Host,
Destination,
MessageId,
Transaction,
Subscription,
Version,
StompException (..),
connect,
stomp,
connect',
stomp',
disconnect,
send,
subscribe,
unsubscribe,
ack,
nack,
begin,
commit,
abort,
startConsumer,
receiveFrame,
setExcpHandler,
startSendBeat,
startRecvBeat,
sendTimeout,
recvTimeout,
lastSend,
lastRecv,
versions
)
where
import qualified Data.ByteString.UTF8 as BU
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List (intercalate)
import Data.Time (UTCTime, getCurrentTime, diffUTCTime)
import Data.Typeable
import Data.Maybe
import Data.Char (toLower)
import qualified Data.Binary.Builder as B
import Data.Binary.Put
import Network.BSD
import Network.URI
import qualified Network.Socket as N
import qualified Control.Exception as E
import Control.Concurrent
import Control.Monad
import System.IO
-- | A single STOMP frame as sent or received on the wire.
data Frame = Frame {
  -- | The frame command; client frames use 'CC', frames read from the
  -- broker are always wrapped in 'SC' (see 'readCommand').
  command :: Command,
  -- | Headers in the order they appear on the wire.
  headers :: [Header],
  -- | Frame body, excluding the trailing NUL terminator.
  body :: BL.ByteString
  } deriving Show
-- | Commands a client may send to the broker.  The constructor names are
-- written to the wire verbatim via 'show' (see 'sendFrame').
data ClientCommand
  = SEND
  | SUBSCRIBE
  | UNSUBSCRIBE
  | BEGIN
  | COMMIT
  | ABORT
  | ACK
  | NACK
  | DISCONNECT
  | CONNECT
  | STOMP
  deriving (Show, Read, Eq)
-- | Commands the broker may send to the client.  Parsed from the wire via
-- the derived 'Read' instance (see 'readCommand').
data ServerCommand
  = CONNECTED
  | MESSAGE
  | RECEIPT
  | ERROR
  deriving (Show, Read, Eq)
-- | Either a client or a server command.
data Command
  = CC ClientCommand
  | SC ServerCommand
  deriving (Show, Read, Eq)
-- | A frame header as a (name, value) pair.
type Header = (String, String)
-- | State of an open STOMP connection.
--
-- NOTE: 'newConn' builds this value positionally, so the field order must
-- stay in sync with it.
data Connection = Connection {
  -- | Buffered handle over the TCP socket to the broker.
  handle :: Handle,
  -- | Protocol versions taken from the broker's @version@ header.
  versions :: [Version],
  -- | Thread id of the consumer thread, if started ('startConsumer').
  listener :: MVar ThreadId,
  -- | Thread id of the outgoing heart-beat thread ('startSendBeat').
  sendBeat :: MVar ThreadId,
  -- | Thread id of the broker heart-beat watchdog ('startRecvBeat').
  recvBeat :: MVar ThreadId,
  -- | Outgoing heart-beat interval in milliseconds (0 = disabled).
  sendTimeout :: Int,
  -- | Incoming heart-beat timeout in milliseconds (0 = disabled).
  recvTimeout :: Int,
  -- | Time of the last successful write ('sendBuf').
  lastSend :: MVar UTCTime,
  -- | Time of the last data read from the broker ('receiveFrame').
  lastRecv :: MVar UTCTime,
  -- | Serialises writes to 'handle'.
  sockLock :: MVar (),
  -- | 'Nothing' while usable; otherwise the error that closed it, which
  -- 'sendBuf' rethrows to later senders.
  closed :: MVar (Maybe StompException),
  -- | Receipt id that 'disconnect' is waiting for, checked by the consumer.
  disconReq :: MVar String,
  -- | Filled when a background thread finishes; 'disconnect' waits on it.
  disconResp :: MVar (),
  -- | Optional user callback invoked by 'onException'.
  excpHandle :: MVar (Maybe (StompException -> IO ()))
  }
-- | Errors raised by this module.
data StompException
  -- | Handshake refused or connection closed (message from broker or library).
  = ConnectionError String
  -- | The URI given to 'connect' / 'stomp' was not a usable @stomp:@ URI.
  | InvalidUri String
  -- | A frame read from the broker was malformed.
  | InvalidFrame String
  -- | Broker-side failure, e.g. an expired heart-beat timeout.
  | BrokerError String
  -- | An 'E.IOException' raised while writing to the socket.
  | StompIOError E.IOException
  deriving (Typeable, Show, Eq)
instance E.Exception StompException
-- | A URI of the form @stomp://user:pass\@host:port@.
type StompUri = String
-- | Host name or address of the broker.
type Host = String
-- | Destination (queue/topic) name.
type Destination = String
-- | Value of a message's @message-id@ header.
type MessageId = String
-- | Transaction identifier.
type Transaction = String
-- | Subscription identifier (@id@ header).
type Subscription = String
-- | Protocol version as (major, minor), rendered as @major.minor@.
type Version = (Int,Int)
-- | Connect to the broker named by a @stomp:@ URI using a CONNECT frame.
-- Host, port and credentials come from the URI; the credential headers are
-- appended after the caller's headers.
connect :: StompUri -> [Version] -> [Header] -> IO Connection
connect uri vs hs =
  processUri uri >>= \(host, port, uriHeaders) ->
    connect' host port vs (hs ++ uriHeaders)

-- | Like 'connect', but performs the handshake with a STOMP frame.
stomp :: StompUri -> [Header] -> IO Connection
stomp uri hs =
  processUri uri >>= \(host, port, uriHeaders) ->
    stomp' host port (hs ++ uriHeaders)
-- | Connect to an explicit host and port with a CONNECT frame.  The
-- requested versions are advertised in @accept-version@ as a comma-separated
-- list of @major.minor@; an empty list advertises @1.0@.
connect' :: Host -> PortNumber -> [Version] -> [Header] -> IO Connection
connect' h p vs hs =
  mkConnection CONNECT h p (hs ++ [("accept-version", accepted)])
  where
    accepted
      | null vs = "1.0"
      | otherwise =
          intercalate "," [show major ++ "." ++ show minor | (major, minor) <- vs]

-- | Connect to an explicit host and port with a STOMP frame.
stomp' :: Host -> PortNumber -> [Header] -> IO Connection
stomp' h p hs = mkConnection STOMP h p hs
-- | Send DISCONNECT and release the connection.
--
-- With a @receipt@ header, the request is registered in 'disconReq' and we
-- wait on 'disconResp' (filled when the consumer thread ends, e.g. after it
-- sees the matching RECEIPT) before stopping the background threads.
-- Without one, the threads are stopped first and DISCONNECT is sent.
--
-- Cleanup always runs, even when sending DISCONNECT throws (e.g. the
-- connection already failed), so the socket handle is never leaked.  The
-- connection is marked closed before the handle is closed.  'sendBuf' holds
-- the 'closed' lock while writing, so no sender can be mid-write on the
-- handle when it is closed.
disconnect :: Connection -> [Header] -> IO ()
disconnect con hs = sendDisconnect `E.finally` release
  where
    frame = Frame (CC DISCONNECT) hs BL.empty
    sendDisconnect =
      case lookup "receipt" hs of
        Nothing -> do
          stop con
          sendFrame con frame
        Just r -> do
          putMVar (disconReq con) r
          sendFrame con frame
          readMVar $ disconResp con
          stop con
    release = do
      -- Idempotent: 'stop' only kills threads still registered.
      stop con
      modifyMVar_ (closed con) $ \x ->
        return (Just $ fromMaybe (ConnectionError "Connection closed") x)
      hClose $ handle con
-- | Send a SEND frame to a destination.  A @content-length@ header (the
-- body's byte length) is added unless the caller already supplied one.
send :: Connection -> Destination -> [Header] -> BL.ByteString -> IO ()
send con dest hs payload = sendFrame con (Frame (CC SEND) allHeaders payload)
  where
    allHeaders = hs ++ ("destination", dest) : lengthHeader
    lengthHeader =
      case lookup "content-length" hs of
        Just _ -> []
        Nothing -> [("content-length", show (BL.length payload))]
-- | Send a SUBSCRIBE frame; @destination@ and @id@ headers follow the
-- caller's headers.
subscribe :: Connection -> Destination -> Subscription -> [Header] -> IO ()
subscribe con dest sub hs =
  let extra = [("destination", dest), ("id", sub)]
  in sendFrame con (Frame (CC SUBSCRIBE) (hs ++ extra) BL.empty)
-- | Send an UNSUBSCRIBE frame for the given subscription id.
unsubscribe :: Connection -> Subscription -> [Header] -> IO ()
unsubscribe con sub hs =
  let extra = [("id", sub)]
  in sendFrame con (Frame (CC UNSUBSCRIBE) (hs ++ extra) BL.empty)
-- | Acknowledge a message (ACK frame).
ack :: Connection -> Subscription -> MessageId -> [Header] -> IO ()
ack = sendAckLike ACK

-- | Reject a message (NACK frame).
nack :: Connection -> Subscription -> MessageId -> [Header] -> IO ()
nack = sendAckLike NACK

-- | Shared body of 'ack' and 'nack': appends @subscription@ and
-- @message-id@ headers and sends a body-less frame.
sendAckLike :: ClientCommand -> Connection -> Subscription -> MessageId
            -> [Header] -> IO ()
sendAckLike cmd con sub msgid hs =
  sendFrame con (Frame (CC cmd) (hs ++ extra) BL.empty)
  where
    extra = [("subscription", sub), ("message-id", msgid)]
-- | Start a transaction (BEGIN frame).
begin :: Connection -> Transaction -> [Header] -> IO ()
begin = sendTxFrame BEGIN

-- | Commit a transaction (COMMIT frame).
commit :: Connection -> Transaction -> [Header] -> IO ()
commit = sendTxFrame COMMIT

-- | Roll back a transaction (ABORT frame).
abort :: Connection -> Transaction -> [Header] -> IO ()
abort = sendTxFrame ABORT

-- | Shared body of the transaction commands: appends a @transaction@
-- header and sends a body-less frame.
sendTxFrame :: ClientCommand -> Connection -> Transaction -> [Header] -> IO ()
sendTxFrame cmd con tid hs =
  sendFrame con (Frame (CC cmd) (hs ++ [("transaction", tid)]) BL.empty)
-- | Open a TCP connection to the broker and perform the STOMP handshake.
--
-- Sends a @cmd@ frame (CONNECT or STOMP) with the given headers and reads
-- the broker's reply.  Any reply other than CONNECTED becomes a
-- 'ConnectionError' carrying the reply body.  If the handshake fails for any
-- reason, the socket handle is closed before the exception propagates, so a
-- refused login does not leak a descriptor.
--
-- The negotiated heart-beat intervals ('getBeats') and the broker's
-- @version@ header (default @1.0@) are stored in the returned 'Connection'.
mkConnection :: ClientCommand -> Host -> PortNumber -> [Header] -> IO Connection
mkConnection cmd host port hs = do
  con <- newConn host port hs
  handshake con `E.onException` closeQuietly (handle con)
  where
    handshake con = do
      sendFrame con (Frame (CC cmd) hs BL.empty)
      Frame reply hs' payload <- receiveFrame con
      when (reply /= SC CONNECTED) $
        E.throwIO $ ConnectionError (BL.unpack payload)
      let (sBeat, rBeat) = getBeats hs hs'
      return con {recvTimeout = rBeat, sendTimeout = sBeat, versions = ver hs'}
    -- Best effort: the handshake error is what the caller needs to see.
    closeQuietly h = do
      _ <- E.try (hClose h) :: IO (Either E.IOException ())
      return ()
    ver h = maybe [(1,0)] parseVer $ lookup "version" h
    -- Parse "1.0,1.1"-style lists, skipping malformed entries instead of
    -- crashing later on a lazy 'read' failure.
    parseVer :: String -> [Version]
    parseVer vs =
      case mapMaybe mkV (splitComma vs) of
        [] -> [(1,0)]
        parsed -> parsed
    splitComma s =
      case break (== ',') s of
        (v, ',' : rest) -> v : splitComma rest
        (v, _) -> [v]
    mkV s =
      case break (== '.') s of
        (major, '.' : minor) -> liftM2 (,) (readNum major) (readNum minor)
        _ -> Nothing
    readNum s =
      case [n | (n, t) <- reads s, ("", "") <- lex t] of
        [n] -> Just n
        _ -> Nothing
-- | Open a blocking, block-buffered TCP handle to @host:port@ (IPv4).
--
-- @host@ may be a dotted-quad address or a host name.  Names that
-- 'N.inet_addr' rejects (e.g. @localhost@) are resolved with
-- 'getHostByName'.  If resolution or connecting fails, the socket is closed
-- before the exception propagates.
openSocket :: String -> PortNumber -> IO Handle
openSocket host port = do
  proto <- getProtocolNumber "tcp"
  sock <- N.socket N.AF_INET N.Stream proto
  h <- connectHandle sock `E.onException` N.sClose sock
  hSetBuffering h (BlockBuffering Nothing)
  return h
  where
    connectHandle sock = do
      addr <- N.inet_addr host `E.catch` byName
      N.connect sock (N.SockAddrInet port addr)
      N.socketToHandle sock ReadWriteMode
    byName :: E.IOException -> IO N.HostAddress
    byName _ = liftM hostAddress (getHostByName host)
-- | Open the socket and allocate all connection state.  Thread slots,
-- timestamps and disconnect MVars start empty, the send lock and the
-- closed/handler slots start full, and both heart-beat intervals are 0.
newConn :: Host -> PortNumber -> [Header] -> IO Connection
newConn host port _ = do
  sockHandle <- openSocket host port
  listenerVar <- newEmptyMVar
  sendBeatVar <- newEmptyMVar
  recvBeatVar <- newEmptyMVar
  lastSendVar <- newEmptyMVar
  lastRecvVar <- newEmptyMVar
  writeLock <- newMVar ()
  closedVar <- newMVar Nothing
  receiptVar <- newEmptyMVar
  doneVar <- newEmptyMVar
  handlerVar <- newMVar Nothing
  return Connection
    { handle = sockHandle
    , versions = []
    , listener = listenerVar
    , sendBeat = sendBeatVar
    , recvBeat = recvBeatVar
    , sendTimeout = 0
    , recvTimeout = 0
    , lastSend = lastSendVar
    , lastRecv = lastRecvVar
    , sockLock = writeLock
    , closed = closedVar
    , disconReq = receiptVar
    , disconResp = doneVar
    , excpHandle = handlerVar
    }
-- | Parse a URI and return its authority component, but only when the
-- scheme is @stomp:@ (case-insensitive).
stompAuth :: String -> Maybe URIAuth
stompAuth str = do
  parsed <- parseURI str
  if map toLower (uriScheme parsed) == "stomp:"
    then uriAuthority parsed
    else Nothing
-- | Split a URI authority into host, port (defaulting to @defPort@) and
-- @login@ / @passcode@ headers.
--
-- 'uriUserInfo' keeps the trailing @\@@, which is stripped before splitting
-- on the first @:@.  Previously a user with no password kept the @\@@ in
-- the login.  User and password are percent-decoded (RFC 3986), so
-- credentials containing reserved characters can be expressed.
fromAuth :: PortNumber -> URIAuth -> (Host, PortNumber, [Header])
fromAuth defPort ua =
  (host, portNum, [("login", user), ("passcode", pwd)])
  where
    userInfo = dropTrailingAt (uriUserInfo ua)
    dropTrailingAt s
      | not (null s) && last s == '@' = init s
      | otherwise = s
    (rawUser, rawPwd) = break (== ':') userInfo
    user = unEscapeString rawUser
    pwd = unEscapeString (drop 1 rawPwd)
    port = drop 1 (uriPort ua)
    -- parseURI only accepts digits after the ':' so 'read' is safe here.
    portNum
      | null port = defPort
      | otherwise = toEnum (read port :: Int)
    host = uriRegName ua
-- | Extract host, port and credential headers from a @stomp:@ URI,
-- throwing 'InvalidUri' when it cannot be used.
processUri :: String -> IO (Host, PortNumber, [Header])
processUri uri =
  case stompAuth uri of
    Nothing -> E.throwIO (InvalidUri uri)
    Just ua -> return (fromAuth defaultPort ua)
-- | Kill whichever background threads (consumer, send beat, receive beat)
-- are registered, emptying their slots.
stop :: Connection -> IO ()
stop c =
  forM_ [listener c, sendBeat c, recvBeat c] $ \slot ->
    tryTakeMVar slot >>= maybe (return ()) killThread
-- | Port used when the URI does not specify one.
defaultPort :: PortNumber
defaultPort = fromIntegral (61613 :: Int)
-- | Install (or replace) the callback that background threads invoke when
-- they stop because of a 'StompException'.
setExcpHandler :: Connection -> (StompException -> IO ()) -> IO ()
setExcpHandler con fun = void $ swapMVar (excpHandle con) (Just fun)
-- | Report a fatal error from a background thread: run the user handler
-- (if any), then record the error in 'closed'.  Only the first error is
-- kept, so later sends rethrow the original cause.
--
-- The handler is read with 'readMVar' and called without holding the
-- 'excpHandle' lock, so it may call 'setExcpHandler' without deadlocking.
-- The connection is marked closed even if the handler throws.
onException :: Connection -> StompException -> IO ()
onException con e = runHandler `E.finally` markClosed
  where
    runHandler = readMVar (excpHandle con) >>= maybe (return ()) ($ e)
    markClosed = modifyMVar_ (closed con) $ \x -> return $ mplus x (Just e)
-- | Start the consumer thread that feeds every received frame to @fun@.
-- Does nothing if a consumer is already registered.  When the thread ends,
-- 'disconResp' is filled so 'disconnect' can proceed.
--
-- The previous version took the running thread id out of 'listener' and
-- never put it back, so 'stop' could no longer kill that thread and a later
-- call would start a second consumer.
startConsumer :: Connection -> (Frame -> IO ()) -> IO ()
startConsumer c fun = do
  running <- tryTakeMVar (listener c)
  case running of
    Just tid -> void $ tryPutMVar (listener c) tid
    Nothing -> do
      tid <- forkIO $ consumeFrames c fun `E.finally` tryPutMVar (disconResp c) ()
      void $ tryPutMVar (listener c) tid
-- | Consumer loop: read a frame, pass it to @fun@, and stop after the
-- RECEIPT matching a pending 'disconReq' has been delivered.  A
-- 'StompException' ends the loop and is reported via 'onException'.
--
-- The handler is installed once around the whole loop.  The previous
-- version recursed inside 'E.catch', adding a catch frame per received
-- frame, so memory grew without bound over the connection's lifetime.
consumeFrames :: Connection -> (Frame -> IO ()) -> IO ()
consumeFrames con fun =
  loop `E.catch` \(e :: StompException) -> onException con e
  where
    loop = do
      frame <- receiveFrame con
      fun frame
      pending <- tryTakeMVar (disconReq con)
      case pending of
        Just r | checkReceipt r frame -> return ()
        _ -> do
          -- Not the awaited receipt: put the request back and continue.
          maybe (return ()) (putMVar (disconReq con)) pending
          loop
-- | True iff the frame is a RECEIPT whose @receipt-id@ equals @expected@.
checkReceipt :: String -> Frame -> Bool
checkReceipt expected (Frame (SC RECEIPT) hs _) =
  lookup "receipt-id" hs == Just expected
checkReceipt _ _ = False
-- | Serialise a client frame (command line, headers, blank line, body, NUL)
-- and write it via 'sendBuf'.  Header names and values are escaped with
-- 'esc' except in CONNECT frames, which are written verbatim.  Only 'CC'
-- frames are supported.
sendFrame :: Connection -> Frame -> IO ()
sendFrame con f = sendBuf con (flatten (runPut (putFrame f)))
  where
    putFrame (Frame (CC cmd) hs payload) = do
      putByteString
        (BU.fromString (show cmd ++ "\n" ++ renderHeaders cmd hs ++ "\n"))
      unless (BL.null payload) $ putLazyByteString payload
      putWord8 0x00
    flatten = BS.concat . BL.toChunks
    renderHeaders :: ClientCommand -> [Header] -> String
    renderHeaders cmd = concatMap (renderOne (encoderFor cmd))
    renderOne enc (k, v) = enc k ++ ":" ++ enc v ++ "\n"
    encoderFor CONNECT = id
    encoderFor _ = esc
-- | Write raw bytes to the broker under the send lock, flush, and record
-- the send time.  Rethrows the stored error if the connection is closed;
-- I/O failures are wrapped in 'StompIOError'.
sendBuf :: Connection -> BS.ByteString -> IO ()
sendBuf con bs =
  withMVar (closed con) $ \state ->
    case state of
      Just err -> E.throwIO err
      Nothing ->
        writeLocked `E.catch` \(e :: E.IOException) -> E.throwIO $ StompIOError e
  where
    writeLocked = withMVar (sockLock con) $ \_ -> do
      BS.hPut (handle con) bs
      hFlush (handle con)
      beatTime (lastSend con)
-- | Block until a complete frame has been read from the broker, then
-- record the receive time.
receiveFrame :: Connection -> IO Frame
receiveFrame con = do
  cmd <- readCommand con
  hdrs <- readHeaders con cmd
  payload <- readBody con hdrs
  beatTime (lastRecv con)
  return Frame { command = cmd, headers = hdrs, body = payload }
-- | Read one line and decode it as UTF-8.
--
-- STOMP 1.2 allows either LF or CRLF as the end of line.  'BS.hGetLine'
-- strips only the LF, so a trailing CR is removed here.  Otherwise header
-- values would keep a stray '\r', and a CRLF heart-beat would be taken for
-- a command.
readLine :: Handle -> IO String
readLine h = liftM (BU.toString . stripCR) (BS.hGetLine h)
  where
    stripCR l
      | not (BS.null l) && BS.last l == '\r' = BS.init l
      | otherwise = l
-- | Read the command line of the next frame.
--
-- Empty lines (after dropping NULs) are heart-beats: they refresh
-- 'lastRecv' and are skipped.  EOF raises 'ConnectionError'.  An
-- unrecognised command raises 'InvalidFrame' instead of the ErrorCall a
-- bare 'read' would throw, which no 'StompException' handler could catch.
readCommand :: Connection -> IO Command
readCommand con = do
  eof <- hIsEOF (handle con)
  if eof
    then E.throwIO $ ConnectionError "Connection closed by broker"
    else do
      l <- readLine (handle con)
      case dropWhile (== '\x00') l of
        [] -> do
          beatTime (lastRecv con)
          readCommand con
        l' ->
          -- Same acceptance rule as 'read', but total.
          case [c | (c, t) <- reads l', ("", "") <- lex t] of
            [c] -> return (SC c)
            _ -> E.throwIO $ InvalidFrame ("Unknown server command: " ++ l')
-- | Read header lines up to the blank line that ends the header section.
--
-- Each line is split at its first ':' before unescaping, so an escaped
-- colon (@\\c@) in a header name is not taken for the separator.  CONNECTED
-- headers are not unescaped, matching how CONNECT headers are sent
-- unescaped.  A line without ':' raises 'InvalidFrame'; previously 'tail'
-- failed lazily later.
readHeaders :: Connection -> Command -> IO [Header]
readHeaders con cmd = loop
  where
    decode = case cmd of
      SC CONNECTED -> id
      _ -> unesc
    loop = do
      l <- readLine (handle con)
      if null l
        then return []
        else case break (== ':') l of
          (name, ':' : val) -> do
            rest <- loop
            return ((decode name, decode val) : rest)
          _ -> E.throwIO $ InvalidFrame ("Malformed header line: " ++ l)
-- | Read the frame body and consume its NUL terminator.
--
-- With a @content-length@ header, exactly that many octets are read and
-- the next octet must be NUL, otherwise 'InvalidFrame' is raised.  A
-- missing or negative length is also 'InvalidFrame'.  Without the header,
-- octets are read up to the first NUL.  Hitting EOF there raises
-- 'ConnectionError'; the previous version looped forever on EOF.  That
-- loop is now tail-recursive with an accumulating builder.
readBody :: Connection -> [Header] -> IO BL.ByteString
readBody con hs = maybe (readTill B.empty) readBuf len
  where
    h = handle con
    len = lookup "content-length" hs
    readBuf x =
      case [n | (n, t) <- reads x, ("", "") <- lex t] of
        [n] | n >= 0 -> do
          bs <- BL.hGet h n
          tr <- BL.hGet h 1
          when (tr /= term) $
            E.throwIO $ InvalidFrame "Missing frame terminator"
          return bs
        _ -> E.throwIO $ InvalidFrame ("Invalid content-length header: " ++ x)
    readTill acc = do
      ch <- BL.hGet h 1
      if BL.null ch
        then E.throwIO $ ConnectionError "Connection closed by broker"
        else if ch == term
          then return (B.toLazyByteString acc)
          else readTill (acc `B.append` B.fromLazyByteString ch)
-- | The single NUL octet that terminates every STOMP frame.
term :: BL.ByteString
term = BL.singleton '\0'
-- | Escape a header name or value: ':' becomes @\\c@, newline becomes
-- @\\n@, and backslash is doubled.  All other characters are unchanged.
esc :: String -> String
esc = foldr step []
  where
    step ':' acc = '\\' : 'c' : acc
    step '\n' acc = '\\' : 'n' : acc
    step '\\' acc = '\\' : '\\' : acc
    step ch acc = ch : acc
-- | Undo 'esc': @\\n@ becomes a newline, @\\c@ becomes ':', and @\\\\@ becomes
-- a backslash.  Unknown escapes and a trailing lone backslash are kept
-- literally.
unesc :: String -> String
unesc ('\\' : c : rest)
  | Just decoded <- lookup c table = decoded : unesc rest
  where
    table = [('n', '\n'), ('c', ':'), ('\\', '\\')]
unesc (c : rest) = c : unesc rest
unesc [] = []
-- | Negotiate heart-beat intervals from the client's and the broker's
-- @heart-beat@ headers.  Returns (send interval, receive timeout) in
-- milliseconds.  Each side is 0 (disabled) if either party declared 0,
-- otherwise the larger of the two values.  The receive timeout is doubled
-- as a grace margin.
getBeats :: [Header] -> [Header] -> (Int, Int)
getBeats xs ys =
  (negotiate clientSend serverRecv, 2 * negotiate clientRecv serverSend)
  where
    (clientSend, clientRecv) = beat xs
    (serverSend, serverRecv) = beat ys
    negotiate a b
      | a == 0 || b == 0 = 0
      | otherwise = max a b
-- | Parse a @heart-beat: x,y@ header into (x, y).  An absent header means
-- (0,0), i.e. no heart-beating.  A malformed header (no comma, non-numeric
-- or negative values) is treated the same way.  Previously 'read'/'tail'
-- crashed lazily when the values were later forced.
beat :: [Header] -> (Int,Int)
beat hs = fromMaybe (0,0) (lookup "heart-beat" hs >>= parseBeat)
  where
    parseBeat s =
      case break (== ',') s of
        (x, ',' : y) -> liftM2 (,) (readNat x) (readNat y)
        _ -> Nothing
    readNat s =
      case [n | (n, t) <- reads s, ("", "") <- lex t] of
        [n] | n >= 0 -> Just n
        _ -> Nothing
-- | Store the current time in the MVar, filling it if empty and replacing
-- the value otherwise.
beatTime :: MVar UTCTime -> IO ()
beatTime v = getCurrentTime >>= \now -> do
  stored <- tryPutMVar v now
  unless stored $ void (swapMVar v now)
-- | Outgoing heart-beat loop.  If nothing has been sent for 'sendTimeout'
-- milliseconds, send a heart-beat and sleep a full interval; otherwise
-- sleep for the remainder of the interval.  A 'StompException' ends the
-- loop and is reported via 'onException'.
--
-- Fixes: the remaining delay is @delay - elapsed@ (the operator was
-- missing, which is a type error).  The handler now wraps the whole loop
-- instead of nesting one catch frame per iteration.
clientBeat :: Connection -> IO ()
clientBeat con =
  loop `E.catch` \(e :: StompException) -> onException con e
  where
    delay = sendTimeout con
    loop = do
      prev <- readMVar (lastSend con)
      now <- getCurrentTime
      let elapsed = floor $ diffUTCTime now prev * 1000
      if elapsed >= delay
        then do
          sendBuf con (BU.fromString "\n\x00")
          threadDelay (1000 * delay)
        else threadDelay (1000 * (delay - elapsed))
      loop
-- | Broker heart-beat watchdog.  If nothing has been received for
-- 'recvTimeout' milliseconds, raise 'BrokerError'; otherwise sleep for the
-- remainder of the window.  The error (or any other 'StompException') is
-- reported via 'onException' and ends the loop.
--
-- Fixes: the remaining delay is @delay - elapsed@ (the operator was
-- missing, which is a type error).  The handler now wraps the whole loop
-- instead of nesting one catch frame per iteration.
brokerBeat :: Connection -> IO ()
brokerBeat con =
  loop `E.catch` \(e :: StompException) -> onException con e
  where
    delay = recvTimeout con
    loop = do
      prev <- readMVar (lastRecv con)
      now <- getCurrentTime
      let elapsed = floor $ diffUTCTime now prev * 1000
      if elapsed >= delay
        then E.throwIO $ BrokerError "Broker timeout expired"
        else threadDelay (1000 * (delay - elapsed))
      loop
-- | Start the outgoing heart-beat thread if heart-beating was negotiated
-- ('sendTimeout' > 0) and no such thread is registered yet.  When the
-- thread ends, 'disconResp' is filled.
--
-- A running thread id is now put back.  Previously it was taken out of
-- 'sendBeat' and lost, so 'stop' could no longer kill the thread.
startSendBeat :: Connection -> IO ()
startSendBeat c = do
  running <- tryTakeMVar (sendBeat c)
  case running of
    Just tid -> void $ tryPutMVar (sendBeat c) tid
    Nothing -> when (sendTimeout c > 0) $ do
      tid <- forkIO $ clientBeat c `E.finally` tryPutMVar (disconResp c) ()
      void $ tryPutMVar (sendBeat c) tid
-- | Start the broker heart-beat watchdog if a receive timeout was
-- negotiated ('recvTimeout' > 0) and no watchdog is registered yet.  When
-- the thread ends, 'disconResp' is filled.
--
-- A running thread id is now put back.  Previously it was taken out of
-- 'recvBeat' and lost, so 'stop' could no longer kill the thread.
startRecvBeat :: Connection -> IO ()
startRecvBeat c = do
  running <- tryTakeMVar (recvBeat c)
  case running of
    Just tid -> void $ tryPutMVar (recvBeat c) tid
    Nothing -> when (recvTimeout c > 0) $ do
      tid <- forkIO $ brokerBeat c `E.finally` tryPutMVar (disconResp c) ()
      void $ tryPutMVar (recvBeat c) tid