module Network.Nats (
Nats
, NatsSID
, connect
, NatsException
, MsgCallback
, subscribe
, unsubscribe
, publish
, request
, disconnect
) where
import System.IO
import Control.Concurrent.MVar
import Control.Concurrent
import qualified Network.Socket as S
import Network.Socket (SocketOption(KeepAlive), setSocketOption, getAddrInfo, SockAddr(..))
import Control.Monad (forever, replicateM)
import Data.Dequeue as D
import Control.Applicative ((<$>))
import Data.Typeable
import qualified Data.Foldable as FOLD
import Control.Exception (bracket, bracketOnError, throwIO, catch, IOException, AsyncException, Exception)
import System.Random (randomRIO)
import Data.IORef
import qualified Data.Map.Strict as Map
import qualified Data.ByteString.Lazy.Char8 as BL
import qualified Data.ByteString.Char8 as BS
import Data.Char (toLower, isUpper)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8)
import qualified Data.Aeson as AE
import Data.Aeson.TH (deriveJSON, defaultOptions, fieldLabelModifier)
import qualified Network.URI as URI
-- | Exception thrown by this client for protocol-level failures: a failed
-- authentication handshake, an unexpected line from the server, an error
-- reply to a subscription or request, or a connection lost while such a
-- reply was pending. The 'String' is a human-readable description.
data NatsException = NatsException String
    deriving (Show, Typeable)

instance Exception NatsException
-- | Options sent to the server as the JSON body of the @CONNECT@ command.
-- JSON keys are derived from the field names by the 'deriveJSON' splice
-- that follows this declaration (e.g. 'natsConnSslRequired' becomes
-- @ssl_required@).
data NatsConnectionOptions = NatsConnectionOptions {
      natsConnUser        :: T.Text  -- ^ user name (JSON key @user@)
    , natsConnPass        :: T.Text  -- ^ password (JSON key @pass@)
    , natsConnVerbose     :: Bool    -- ^ JSON key @verbose@
    , natsConnPedantic    :: Bool    -- ^ JSON key @pedantic@
    , natsConnSslRequired :: Bool    -- ^ JSON key @ssl_required@
    } deriving (Show)
-- | Baseline connection options: user and password both @"nats"@, verbose
-- and pedantic flags set, SSL not required. 'connect' starts from these and
-- replaces the user and password with the values taken from the URL.
defaultConnectionOptions :: NatsConnectionOptions
defaultConnectionOptions =
    NatsConnectionOptions
        { natsConnUser        = "nats"
        , natsConnPass        = "nats"
        , natsConnVerbose     = True
        , natsConnPedantic    = True
        , natsConnSslRequired = False
        }
-- Derive aeson's ToJSON/FromJSON instances for 'NatsConnectionOptions'.
-- The field label modifier turns a Haskell field name into its JSON key;
-- worked example for natsConnSslRequired:
--   drop 8                      (strip "natsConn")          -> "SslRequired"
--   foldl insertUnderscore []   (conses chars, so the string
--                                is built reversed, with '_'
--                                before every capital)      -> "deriuqeR_lsS_"
--   reverse                                                 -> "_Ssl_Required"
--   drop 1                      (leading underscore)        -> "Ssl_Required"
--   map toLower                                             -> "ssl_required"
-- Resulting keys: user, pass, verbose, pedantic, ssl_required.
-- NOTE(review): the same modifier is repeated for NatsServerInfo with a
-- different prefix length; presumably it is inlined because GHC's stage
-- restriction forbids a splice from calling functions defined in this module.
$(deriveJSON defaultOptions{fieldLabelModifier =(
    let insertUnderscore acc chr
            | isUpper chr = chr : '_' : acc
            | otherwise = chr : acc
    in
        map toLower . drop 1 . reverse . foldl insertUnderscore [] . drop 8
    )} ''NatsConnectionOptions)
-- | Server description parsed from the JSON body of the server's @INFO@
-- line. JSON keys are derived by the 'deriveJSON' splice that follows this
-- declaration.
data NatsServerInfo = NatsServerInfo {
      natsSvrServerId     :: T.Text  -- ^ JSON key @server_id@
    , natsSvrVersion      :: T.Text  -- ^ JSON key @version@
    , natsSvrMaxPayload   :: Int     -- ^ JSON key @max_payload@
    , natsSvrAuthRequired :: Bool
      -- ^ JSON key @auth_required@; when 'True' the client answers the
      -- @INFO@ line with a @CONNECT@ carrying its credentials.
    } deriving (Show)
-- Derive aeson's ToJSON/FromJSON instances for 'NatsServerInfo'.
-- Same field label modifier as for NatsConnectionOptions, but dropping the
-- 7-character "natsSvr" prefix; worked example for natsSvrMaxPayload:
--   drop 7                      -> "MaxPayload"
--   foldl insertUnderscore []   -> "daolyaP_xaM_"   (reversed, '_' before capitals)
--   reverse                     -> "_Max_Payload"
--   drop 1                      -> "Max_Payload"
--   map toLower                 -> "max_payload"
-- Resulting keys: server_id, version, max_payload, auth_required.
$(deriveJSON defaultOptions{fieldLabelModifier =(
    let insertUnderscore acc chr
            | isUpper chr = chr : '_' : acc
            | otherwise = chr : acc
    in
        map toLower . drop 1 . reverse . foldl insertUnderscore [] . drop 7
    )} ''NatsServerInfo)
-- | Subscription identifier. 'show' renders it as a bare decimal number
-- (the form used on the wire by 'makeClntMsg'), and 'read' parses that
-- same form back.
newtype NatsSID = NatsSID Int deriving (Num, Ord, Eq)

instance Show NatsSID where
    show (NatsSID n) = show n

instance Read NatsSID where
    readsPrec prec input =
        [ (NatsSID n, remainder) | (n, remainder) <- readsPrec prec input ]
-- | Handler invoked for every message delivered to a subscription.
-- Arguments, in order: the subscription id, the subject the message was
-- published on, the payload, and the reply subject if the publisher set
-- one. It is called from the client's background connection thread.
type MsgCallback = NatsSID        -- subscription id
                 -> String        -- subject of the message
                 -> BL.ByteString -- payload
                 -> Maybe String  -- reply-to subject, if any
                 -> IO ()
-- | A registered subscription. It is stored in 'natsSubMap' so incoming
-- @MSG@ lines can be dispatched to its callback and so the @SUB@ can be
-- re-sent after a reconnect.
data NatsSubscription = NatsSubscription {
      subSubject  :: Subject        -- ^ subscribed subject
    , subQueue    :: Maybe Subject  -- ^ queue group, if any
    , subCallback :: MsgCallback    -- ^ handler for delivered messages
    , subSid      :: NatsSID        -- ^ subscription id
    }
-- | Reply callbacks for sent commands, oldest first. The connection thread
-- pops the front entry and calls it with 'Nothing' on @+OK@ or with
-- @'Just' errorText@ on @-ERR@; on connection loss every remaining entry is
-- called with the error text.
type FifoQueue = D.BankersDequeue (Maybe T.Text -> IO ())
-- | Client handle returned by 'connect' and passed to every operation.
data Nats = Nats {
      natsConnOptions :: NatsConnectionOptions -- ^ options sent in @CONNECT@
    , natsHost        :: String                -- ^ server host name
    , natsPort        :: Int                   -- ^ server port
    , natsRuntime     :: MVar ( Handle    -- socket handle; not valid while disconnected
                              , FifoQueue -- reply callbacks awaiting +OK / -ERR; not valid while disconnected
                              , Bool      -- True while a connection is up
                              , MVar ()   -- while disconnected: filled once the next connection is established
                              )
      -- ^ Mutable connection state, guarded by the 'MVar'. When the flag is
      -- 'False' the handle and queue slots hold placeholders and must not be
      -- forced.
    , natsThreadId    :: MVar ThreadId         -- ^ background connection thread
    , natsNextSid     :: IORef NatsSID         -- ^ next subscription id to hand out
    , natsSubMap      :: IORef (Map.Map NatsSID NatsSubscription)
      -- ^ active subscriptions, keyed by id
    }
-- | A decoded line received from the server (see 'decodeMessage').
data NatsSvrMessage =
      -- | @MSG@: a delivery for subscription 'msgSid'. 'msgText' holds the
      -- payload, which is read from the socket after the header line.
      NatsSvrMsg { msgSubject::String, msgSid::NatsSID, msgText::BS.ByteString, msgReply::Maybe String}
      -- | @+OK@ acknowledgement of a command
    | NatsSvrOK
      -- | @-ERR@ with the server's error text
    | NatsSvrError T.Text
      -- | @PING@ (answered with @PONG@)
    | NatsSvrPing
      -- | @PONG@ (ignored)
    | NatsSvrPong
      -- | @INFO@ with the parsed server description
    | NatsSvrInfo NatsServerInfo
    deriving (Show)
-- | A subject or queue-group name that has passed 'makeSubject' validation.
newtype Subject = Subject String deriving (Show)

-- | The textual form of a 'Subject'.
subjectToStr :: Subject -> String
subjectToStr subj = case subj of
    Subject s -> s

-- | Validate and wrap a subject name. Calls 'error' if any character is a
-- space or a control character (anything @<= ' '@). Subjects are written as
-- space-separated fields of the protocol lines built by 'makeClntMsg', so
-- such characters would corrupt the command.
makeSubject :: String -> Subject
makeSubject str =
    if all (> ' ') str
        then Subject str
        else error $ "Subject contains incorrect characters: " ++ str
-- | Commands sent from the client to the server; serialized by
-- 'makeClntMsg'.
data NatsClntMessage =
      NatsClntPing                                           -- ^ @PING@
    | NatsClntPong                                           -- ^ @PONG@
    | NatsClntSubscribe Subject NatsSID (Maybe Subject)      -- ^ @SUB subject [queue] sid@
    | NatsClntUnsubscribe NatsSID                            -- ^ @UNSUB sid@
    | NatsClntPublish Subject (Maybe Subject) BL.ByteString  -- ^ @PUB subject [reply] length@, then the payload
    | NatsClntConnect NatsConnectionOptions                  -- ^ @CONNECT@ followed by the JSON-encoded options
-- | Serialize a client command to its wire form. The result does not end in
-- the CRLF that terminates a command; callers append it. For @PUB@, the
-- header line (with its own CRLF) is followed directly by the payload bytes.
makeClntMsg :: NatsClntMessage -> BL.ByteString
makeClntMsg command = BL.fromChunks (encode command)
  where
    -- Join protocol fields with single spaces.
    joinFields :: [String] -> String
    joinFields = unwords

    -- An optional subject contributes zero or one field.
    maybeField :: Maybe Subject -> [String]
    maybeField = maybe [] (\s -> [subjectToStr s])

    encode :: NatsClntMessage -> [BS.ByteString]
    encode NatsClntPing = ["PING"]
    encode NatsClntPong = ["PONG"]
    encode (NatsClntSubscribe subject sid mqueue) =
        [BS.pack $ joinFields $ ["SUB", subjectToStr subject] ++ maybeField mqueue ++ [show sid]]
    encode (NatsClntUnsubscribe sid) =
        [BS.pack $ joinFields ["UNSUB", show sid]]
    encode (NatsClntPublish subj mreply payload) =
        let header = joinFields (["PUB", subjectToStr subj] ++ maybeField mreply ++ [show (BL.length payload)])
        in BS.pack (header ++ "\r\n") : BL.toChunks payload
    encode (NatsClntConnect opts) =
        "CONNECT " : BL.toChunks (AE.encode opts)
-- | Parse one protocol line received from the server, as returned by
-- 'BS.hGetLine' (so it may still end in @\\r@, which is stripped here).
--
-- Returns the decoded message and, for @MSG@, the payload length announced
-- in the header. The payload itself is read separately by the caller, so
-- 'msgText' is left empty here. Returns 'Nothing' for unknown commands and
-- for malformed lines: invalid JSON after @INFO@, or a wrong field count, a
-- non-numeric sid/length, or a negative length after @MSG@. Malformed input
-- therefore never crashes the caller.
decodeMessage :: BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
decodeMessage line = decodeMessage_ mid mpayload
  where
    isSep x = x == ' ' || x == '\r'
    (mid, rest) = BS.break isSep line
    -- Arguments after the command word, minus the separator and minus the
    -- trailing carriage return (so it does not leak into -ERR text).
    mpayload = stripCR (BS.drop 1 rest)

    stripCR s
        | not (BS.null s) && BS.last s == '\r' = BS.init s
        | otherwise = s

    decodeMessage_ :: BS.ByteString -> BS.ByteString -> Maybe (NatsSvrMessage, Maybe Int)
    decodeMessage_ "PING" _ = Just (NatsSvrPing, Nothing)
    decodeMessage_ "PONG" _ = Just (NatsSvrPong, Nothing)
    decodeMessage_ "+OK" _ = Just (NatsSvrOK, Nothing)
    decodeMessage_ "-ERR" msg = Just (NatsSvrError (decodeUtf8 msg), Nothing)
    decodeMessage_ "INFO" msg =
        (\info -> (NatsSvrInfo info, Nothing)) <$> AE.decode (BL.fromChunks [msg])
    decodeMessage_ "MSG" msg =
        case map BS.unpack (BS.split ' ' msg) of
            [subj, sid, len] -> mkMsg subj sid Nothing len
            [subj, sid, reply, len] -> mkMsg subj sid (Just reply) len
            _ -> Nothing
    decodeMessage_ _ _ = Nothing

    -- Build an MSG header result, rejecting unparsable or negative numbers
    -- instead of crashing with 'read'.
    mkMsg :: String -> String -> Maybe String -> String -> Maybe (NatsSvrMessage, Maybe Int)
    mkMsg subj sidStr reply lenStr = do
        sid <- readExact sidStr
        len <- readExact lenStr
        if len < 0
            then Nothing
            else Just (NatsSvrMsg subj sid BS.empty reply, Just len)

    -- Total replacement for 'read': the whole string must be consumed.
    readExact :: Read a => String -> Maybe a
    readExact s = case reads s of
        [(v, "")] -> Just v
        _ -> Nothing
-- | Allocate a fresh subscription id. Atomically returns the current counter
-- value and stores its successor.
newNatsSid :: Nats -> IO NatsSID
newNatsSid nats =
    let counter = natsNextSid nats
        bump current = (current + 1, current)
    in atomicModifyIORef' counter bump
-- | Generate a random inbox subject: @_INBOX.@ followed by 13 random
-- lower-case ASCII letters.
newInbox :: IO String
newInbox = ("_INBOX." ++) <$> replicateM 13 (randomRIO ('a', 'z'))
-- | Open a TCP stream connection to @hostname@ on @port@ and return it as a
-- read/write 'Handle' with buffering disabled and TCP keep-alive enabled.
--
-- The host is resolved with 'getAddrInfo' and the first result is used. For
-- IPv4 and IPv6 results the port in that address is replaced by @port@. If
-- anything fails after the socket has been created, the socket is closed.
connectToServer :: String -> Int -> IO Handle
connectToServer hostname port = do
    candidates <- getAddrInfo Nothing (Just hostname) Nothing
    let target = head candidates
    bracketOnError (openSocket target) S.sClose (connectTo target)
  where
    openSocket info = S.socket (S.addrFamily info) S.Stream S.defaultProtocol

    withPort addr = case addr of
        SockAddrInet _ host -> SockAddrInet (fromIntegral port) host
        SockAddrInet6 _ flow host scope -> SockAddrInet6 (fromIntegral port) flow host scope
        other -> other

    connectTo info sock = do
        setSocketOption sock KeepAlive 1
        S.connect sock (withPort (S.addrAddress info))
        h <- S.socketToHandle sock ReadWriteMode
        hSetBuffering h NoBuffering
        return h
-- | Run @f@ on the current connection's handle and reply-callback queue
-- while holding the runtime 'MVar', and store the queue that @f@ returns.
--
-- When the client is disconnected:
--
-- * with @True@, wait until the connection thread signals that a new
--   connection is up, then retry;
-- * with @False@, do nothing (@f@ is not run).
--
-- If @f@ throws, the runtime state is restored unchanged and the exception
-- propagates.
ensureConnection :: Nats -> Bool -> ((Handle, FifoQueue) -> IO FifoQueue) -> IO ()
ensureConnection nats True f = do
    -- Decide under the lock, but wait for reconnection only after the lock
    -- has been released. Waiting inside the locked region (as before) meant
    -- that an exception during the wait or the retry made the cleanup handler
    -- put a stale state into the already-refilled MVar, which could deadlock.
    mwait <- modifyMVar (natsRuntime nats) attempt
    case mwait of
        Nothing -> return ()
        Just csig -> do
            readMVar csig
            ensureConnection nats True f
  where
    attempt (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return ((handle, nqueue, True, csig), Nothing)
    attempt r@(_, _, False, csig) = return (r, Just csig)
ensureConnection nats False f = modifyMVarMasked_ (natsRuntime nats) attempt
  where
    attempt (handle, queue, True, csig) = do
        nqueue <- f (handle, queue)
        return (handle, nqueue, True, csig)
    attempt (handle, queue, False, csig) =
        return (handle, queue, False, csig)
-- | Send a client command over the current connection. See
-- 'ensureConnection' for what @blockIfDisconnected@ does when there is no
-- connection.
--
-- CONNECT, PUB, SUB and UNSUB always get an entry in the reply-callback
-- queue: the supplied callback, or a no-op when none is given. The
-- connection thread completes these entries in order. Supplying a callback
-- for any other command calls 'error'.
sendMessage :: Nats -> Bool -> NatsClntMessage -> Maybe (Maybe T.Text -> IO ()) -> IO ()
sendMessage nats blockIfDisconnected msg mcb =
    case (takesCallback msg, mcb) of
        (True, Just cb)  -> transmit (`D.pushBack` cb)
        (True, Nothing)  -> transmit (`D.pushBack` ignoreReply)
        (False, Just _)  -> error "Callback not supported"
        (False, Nothing) -> transmit id
  where
    ignoreReply :: Maybe T.Text -> IO ()
    ignoreReply _ = return ()

    -- Write the command and update the callback queue under the lock.
    transmit :: (FifoQueue -> FifoQueue) -> IO ()
    transmit updateQueue =
        ensureConnection nats blockIfDisconnected $ \(handle, queue) -> do
            _sendMessage handle msg
            return (updateQueue queue)

    takesCallback m = case m of
        NatsClntConnect {}     -> True
        NatsClntPublish {}     -> True
        NatsClntSubscribe {}   -> True
        NatsClntUnsubscribe {} -> True
        _                      -> False

-- | Write a serialized command followed by its terminating CRLF.
_sendMessage :: Handle -> NatsClntMessage -> IO ()
_sendMessage handle cmsg =
    BL.hPut handle (makeClntMsg cmsg) >> BS.hPut handle "\r\n"
-- | Run the handshake on a freshly opened connection. Read the server's
-- @INFO@ line; if it says authentication is required, send @CONNECT@ with
-- the client's options and require a @+OK@ reply.
--
-- Throws 'NatsException' if the first line is not @INFO@, if the reply to
-- @CONNECT@ is @-ERR@, or if the reply is anything other than @+OK@.
authenticate :: Nats -> Handle -> IO ()
authenticate nats handle = do
    greeting <- BS.hGetLine handle
    case decodeMessage greeting of
        Just (NatsSvrInfo srvInfo, Nothing)
            | natsSvrAuthRequired srvInfo -> login
            | otherwise                   -> return ()
        _ -> throwIO $ NatsException "Incorrect input from server"
  where
    login = do
        BL.hPut handle $ makeClntMsg (NatsClntConnect $ natsConnOptions nats)
        BS.hPut handle "\r\n"
        reply <- BS.hGetLine handle
        case decodeMessage reply of
            Just (NatsSvrOK, Nothing) -> return ()
            Just (NatsSvrError err, Nothing) ->
                throwIO $ NatsException $ "Authentication error: " ++ show err
            _ -> throwIO $ NatsException "Incorrect server response"
-- | Open a new connection to the configured server, run the handshake, and
-- install the connection as current with an empty reply-callback queue.
-- Then fill the signal 'MVar' stored in the previous (disconnected) state,
-- which wakes callers blocked in 'ensureConnection'.
--
-- If connecting or the handshake fails, the new handle is closed (it
-- previously leaked when 'authenticate' threw), the runtime state is left
-- untouched, and the exception propagates.
prepareConnection :: Nats -> IO ()
prepareConnection nats =
    bracketOnError
        (connectToServer (natsHost nats) (natsPort nats))
        hClose
        (\handle -> do
            authenticate nats handle
            -- Swap in the new connection exception-safely. The (now filled)
            -- signal is kept instead of an 'undefined' placeholder, so nothing
            -- can crash by reading it.
            csig <- modifyMVar (natsRuntime nats) $ \(_, _, _, sig) ->
                return ((handle, D.empty, True, sig), sig)
            -- tryPutMVar: never block if the signal was already filled.
            _ <- tryPutMVar csig ()
            return ())
-- | Body of the background thread started by 'connect'.
--
-- Runs 'connectionHandler' until it fails with an 'IOException' (socket
-- error or EOF). The connection is then torn down: the handle is closed and
-- every pending reply callback receives the error text. A reconnect is
-- attempted every 5 seconds until 'prepareConnection' succeeds, and then
-- message processing resumes.
--
-- An 'AsyncException' (e.g. 'killThread' from 'disconnect') tears down the
-- current connection, if there is one, and ends the thread.
--
-- The reconnect loop runs outside any exception handler, so it neither runs
-- with asynchronous exceptions masked nor adds a stack frame per reconnect.
-- Tear-down is a no-op when the client is already disconnected. Killing the
-- thread during the reconnect wait used to crash on the placeholder handle.
connectionThread :: Nats -> IO ()
connectionThread nats = serve `catch` finalHandler
  where
    runtime = natsRuntime nats

    serve = do
        failure <- (connectionHandler nats >> return Nothing)
            `catch` (\e -> return (Just (e :: IOException)))
        case failure of
            Nothing -> return ()
            Just e -> do
                finalize e
                tryToConnect
                serve

    tryToConnect = do
        threadDelay 5000000
        connected <- (prepareConnection nats >> return True)
            `catch` ((\_ -> return False) :: IOException -> IO Bool)
            `catch` ((\_ -> return False) :: NatsException -> IO Bool)
        if connected then return () else tryToConnect

    -- Switch to the disconnected state with a fresh signal MVar, then close
    -- the old handle and fail all pending reply callbacks with @show e@.
    finalize :: Show e => e -> IO ()
    finalize e = do
        stale <- modifyMVar runtime $ \state@(handle, queue, connected, _) ->
            if connected
                then do
                    finsignal <- newEmptyMVar
                    return ((disconnectedHandle, D.empty, False, finsignal), Just (handle, queue))
                else return (state, Nothing)
        case stale of
            Nothing -> return ()
            Just (handle, queue) -> do
                hClose handle
                FOLD.mapM_ (\cb -> cb $ Just (T.pack $ show e)) queue

    finalHandler :: AsyncException -> IO ()
    finalHandler = finalize

    disconnectedHandle = error "Network.Nats: no connection handle while disconnected"
-- | Main loop of the connection thread for one live connection.
--
-- First re-sends a @SUB@ for every subscription in 'natsSubMap', so
-- subscriptions survive a reconnect. Then it reads and dispatches server
-- lines forever:
--
-- * @PING@ is answered with @PONG@; @PONG@ is ignored;
-- * @+OK@ / @-ERR@ complete the oldest pending reply callback (an @-ERR@
--   with no pending callback is printed to stdout);
-- * @INFO@ requiring authentication triggers a new @CONNECT@;
-- * @MSG@ reads the announced payload plus its trailing two bytes and passes
--   the payload to the subscription's callback, or sends @UNSUB@ if the sid
--   is not registered;
-- * unparseable lines are printed to stdout and skipped.
--
-- I/O errors (e.g. EOF from 'BS.hGetLine') are not caught here.
--
-- NOTE(review): subscription callbacks run synchronously on this thread, and
-- this thread is also the one that completes reply callbacks. A 'MsgCallback'
-- that calls 'subscribe' or 'request' therefore appears to wait forever for
-- an @+OK@ this thread can no longer read. Confirm, and document for users.
connectionHandler :: Nats -> IO ()
connectionHandler nats = do
    (handle, _, _, _) <- readMVar (natsRuntime nats)
    subscriptions <- readIORef (natsSubMap nats)
    -- Re-register known subscriptions on this (possibly new) connection.
    FOLD.forM_ subscriptions $ \(NatsSubscription subject queue _ sid) ->
        sendMessage nats True (NatsClntSubscribe subject sid queue) Nothing
    forever $
        let
            -- Remove and return the oldest pending reply callback, if any.
            popCb (h, queue, x1, x2) = return ((h, newq, x1, x2), item)
              where
                (item, newq) = D.popFront queue
            handleMessage NatsSvrPing = sendMessage nats True NatsClntPong Nothing
            handleMessage NatsSvrPong = return ()
            handleMessage NatsSvrOK = do
                cb <- modifyMVar (natsRuntime nats) $ popCb
                case cb of
                    Just f -> f Nothing
                    Nothing -> return ()
            handleMessage (NatsSvrError txt) = do
                cb <- modifyMVar (natsRuntime nats) $ popCb
                case cb of
                    Just f -> f $ Just txt
                    -- No command is waiting for this error: just report it.
                    Nothing -> putStrLn $ show txt
            handleMessage (NatsSvrInfo (NatsServerInfo {natsSvrAuthRequired=True})) = do
                sendMessage nats True (NatsClntConnect $ natsConnOptions nats) Nothing
            handleMessage (NatsSvrInfo _) = return ()
            handleMessage (NatsSvrMsg {..}) = do
                msubscription <- Map.lookup msgSid <$> readIORef (natsSubMap nats)
                case msubscription of
                    Just subscription -> (subCallback subscription) msgSid msgSubject (BL.fromChunks [msgText]) msgReply
                    -- Message for a sid we no longer track: tell the server to stop.
                    Nothing -> sendMessage nats True (NatsClntUnsubscribe msgSid) Nothing
        in do
            line <- BS.hGetLine handle
            case (decodeMessage line) of
                Just (msg, Nothing) -> do
                    handleMessage msg
                Just (msg@(NatsSvrMsg {}), Just paylen) -> do
                    payload <- BS.hGet handle paylen
                    -- Skip the two bytes after the payload (presumably its CRLF terminator).
                    _ <- BS.hGet handle 2
                    handleMessage msg{msgText=payload}
                _ ->
                    putStrLn $ "Incorrect message: " ++ (show line)
-- | Connect to a NATS server given a URL of the form
-- @nats:\/\/[user[:password]\@]host[:port]@.
--
-- The port defaults to 4222 (the standard NATS port) when omitted; it
-- previously crashed with @read ""@. The user and password from the URL
-- replace those in 'defaultConnectionOptions' (a missing part becomes the
-- empty string). A password-less @user\@@ no longer keeps the @\@@ in the
-- user name.
--
-- The initial connection and handshake happen synchronously, and their
-- exceptions propagate. A background thread is then started to process
-- server messages and to reconnect when the connection drops.
--
-- Calls 'error' if the URL cannot be parsed, its scheme is not @nats:@, it
-- has no authority (host) section, or its port is not a number.
connect :: String
        -> IO Nats
connect uri = do
    let parsedUri = case (URI.parseURI uri) of
            Just x -> x
            Nothing -> error ("Error parsing NATS url: " ++ uri)
    if URI.uriScheme parsedUri /= "nats:"
        then error "Incorrect URL scheme"
        else return ()
    -- uriPort is "" or ":<digits>"; uriUserInfo is "" or "user[:password]@".
    let parsePort :: String -> Int
        parsePort p = case drop 1 p of
            "" -> 4222
            digits -> case reads digits of
                [(n, "")] -> n
                _ -> error ("Invalid port in NATS url: " ++ uri)
        (host, port, user, password) = case (URI.uriAuthority parsedUri) of
            Just (URI.URIAuth {..}) ->
                -- Drop the trailing '@' before splitting at the first ':'.
                let (usr, rest) = break (== ':') (takeWhile (/= '@') uriUserInfo)
                in (uriRegName, parsePort uriPort, usr, drop 1 rest)
            Nothing -> error "Missing hostname section"
    csig <- newEmptyMVar
    -- Start in the disconnected state; prepareConnection installs the real
    -- handle and queue and fills csig.
    mruntime <- newMVar (error "Network.Nats: no connection handle while disconnected", D.empty, False, csig)
    mthreadid <- newEmptyMVar
    nextsid <- newIORef 1
    submap <- newIORef Map.empty
    let opts = defaultConnectionOptions{natsConnUser=T.pack user, natsConnPass=T.pack password}
    let nats = Nats{
              natsConnOptions=opts
            , natsHost=host
            , natsPort=port
            , natsRuntime=mruntime
            , natsThreadId=mthreadid
            , natsNextSid=nextsid
            , natsSubMap=submap
            }
    prepareConnection nats
    threadid <- forkIO $ connectionThread nats
    putMVar mthreadid threadid
    return nats
-- | Subscribe to a subject, optionally as a member of a queue group.
--
-- Blocks until the server answers the @SUB@, first waiting for a reconnect
-- if the client is currently disconnected. On success the subscription is
-- registered (so messages reach @cb@ and it is re-sent after reconnects) and
-- its id is returned. If the server replies with an error, or the connection
-- is lost before the reply, a 'NatsException' with the error text is thrown.
-- An invalid subject or queue name raises an error (see 'makeSubject').
subscribe :: Nats
          -> String          -- ^ subject to subscribe to
          -> (Maybe String)  -- ^ optional queue group
          -> MsgCallback     -- ^ handler for delivered messages
          -> IO NatsSID
subscribe nats subject queue cb = do
    let subj  = makeSubject subject
        group = fmap makeSubject queue
    outcomeVar <- newEmptyMVar :: IO (MVar (Maybe T.Text))
    sid <- newNatsSid nats
    let entry = NatsSubscription { subSubject = subj, subQueue = group, subCallback = cb, subSid = sid }
        register = atomicModifyIORef' (natsSubMap nats) $ \subs ->
            (Map.insert sid entry subs, ())
        -- Runs on the connection thread when the server answers the SUB.
        onReply outcome = do
            case outcome of
                Nothing -> register
                Just _  -> return ()
            putMVar outcomeVar outcome
    sendMessage nats True (NatsClntSubscribe subj sid group) (Just onReply)
    outcome <- takeMVar outcomeVar
    case outcome of
        Nothing  -> return sid
        Just err -> throwIO (NatsException (T.unpack err))
-- | Remove a subscription. It is dropped from the local registry at once,
-- and @UNSUB@ is sent only if the client is currently connected. I/O errors
-- while sending are ignored.
unsubscribe :: Nats
            -> NatsSID
            -> IO ()
unsubscribe nats sid = do
    atomicModifyIORef' (natsSubMap nats) (\subs -> (Map.delete sid subs, ()))
    sendMessage nats False (NatsClntUnsubscribe sid) Nothing `catch` ignoreIOError
  where
    ignoreIOError :: IOException -> IO ()
    ignoreIOError _ = return ()
-- | Publish @body@ to @subject@ and wait for a single reply.
--
-- A random inbox subject is subscribed for the duration of the call and used
-- as the reply-to subject. The first message received on it is returned.
-- Throws 'NatsException' if the publish gets an error reply, or if the
-- connection is lost before the publish is acknowledged. There is no
-- timeout: the call blocks until a reply arrives.
request :: Nats
        -> String
        -> BL.ByteString
        -> IO BL.ByteString
request nats subject body = do
    answer <- newEmptyMVar :: IO (MVar (Either String BL.ByteString))
    inbox <- newInbox
    let onResponse _ _ response _ = do
            _ <- tryPutMVar answer (Right response)
            return ()
        onPubReply merr =
            case merr of
                Nothing  -> return ()
                Just err -> tryPutMVar answer (Left $ T.unpack err) >> return ()
        publishAndWait _ = do
            sendMessage nats True
                (NatsClntPublish (makeSubject subject) (Just $ makeSubject inbox) body)
                (Just onPubReply)
            result <- takeMVar answer
            case result of
                Left err  -> throwIO $ NatsException err
                Right res -> return res
    bracket (subscribe nats inbox Nothing onResponse) (unsubscribe nats) publishAndWait
-- | Publish @body@ on @subject@ without waiting for any acknowledgement.
-- If the client is disconnected the message is dropped, and I/O errors
-- while sending are ignored.
publish :: Nats
        -> String
        -> BL.ByteString
        -> IO ()
publish nats subject body =
    sendMessage nats False message Nothing `catch` swallow
  where
    message = NatsClntPublish (makeSubject subject) Nothing body

    swallow :: IOException -> IO ()
    swallow _ = return ()
-- | Stop the client by killing its background connection thread. Any
-- cleanup is done by that thread's exception handler (see
-- 'connectionThread').
disconnect :: Nats -> IO ()
disconnect nats = readMVar (natsThreadId nats) >>= killThread