module Network.EngineIO
(
initialize
, handler
, EngineIO
, ServerAPI (..)
, SocketApp(..)
, send
, receive
, Socket
, SocketId
, socketId
, getOpenSockets
, dupRawReader
, Packet(..)
, parsePacket
, encodePacket
, PacketType
, PacketContent(..)
, Payload(..)
, parsePayload
, encodePayload
, TransportType(..)
, parseTransportType
) where
import Prelude hiding (any)
import Control.Applicative
import Control.Concurrent (threadDelay)
import Control.Concurrent.MVar (MVar, newMVar, withMVar)
import Control.Exception (SomeException(SomeException), try)
import Control.Monad (MonadPlus, forever, guard, mzero, replicateM, when)
import Control.Monad.Trans.Iter (cutoff, delay, retract)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Loops (unfoldM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Either (eitherT, left)
import Control.Monad.Trans.Maybe (runMaybeT)
import Data.Aeson ((.=))
import Data.Char (digitToInt, intToDigit)
import Data.Foldable (asum, for_)
import Data.Function (fix, on)
import Data.Ix (inRange)
import Data.List (foldl')
import Data.Monoid ((<>), mconcat, mempty)
import Data.Ord (comparing)
import Data.Traversable (for)
import qualified Control.Concurrent.Async as Async
import qualified Control.Concurrent.STM as STM
import qualified Control.Concurrent.STM.Delay as STMDelay
import qualified Data.Aeson as Aeson
import qualified Data.Attoparsec.ByteString as Attoparsec
import qualified Data.Attoparsec.ByteString.Char8 as AttoparsecC8
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Builder as Builder
import qualified Data.ByteString.Char8 as BSChar8
import qualified Data.ByteString.Lazy as LBS
import qualified Data.HashMap.Strict as HashMap
import qualified Data.Text as Text
import qualified Data.Text.Encoding as Text
import qualified Data.Vector as V
import qualified Network.WebSockets as WebSockets
import qualified Network.WebSockets.Connection as WebSockets
import qualified System.Random.MWC as Random
-- | The kind of a 'Packet'.
--
-- The constructors are listed in the same order as the numeric codes
-- (0 through 6) that 'packetTypeToIndex' assigns to them.
data PacketType
  = Open
  | Close
  | Ping
  | Pong
  | Message
  | Upgrade
  | Noop
  deriving (Bounded, Enum, Eq, Read, Show)
-- | Numeric code of a 'PacketType': 0 for 'Open' up to 6 for 'Noop'.
--
-- The result is polymorphic so the same table can produce a raw byte or an
-- 'Int' (which callers turn into an ASCII digit).
packetTypeToIndex :: Num i => PacketType -> i
packetTypeToIndex Open = 0
packetTypeToIndex Close = 1
packetTypeToIndex Ping = 2
packetTypeToIndex Pong = 3
packetTypeToIndex Message = 4
packetTypeToIndex Upgrade = 5
packetTypeToIndex Noop = 6
-- | Inverse of 'packetTypeToIndex': turn a numeric code (0 through 6) back
-- into a 'PacketType'. Any other value yields 'mzero' in the caller's monad.
packetTypeFromIndex :: (Eq i, MonadPlus m, Num i) => i -> m PacketType
packetTypeFromIndex i = maybe mzero return (lookup i codes)
  where
    -- Code/type pairs, in protocol order.
    codes =
      [ (0, Open)
      , (1, Close)
      , (2, Ping)
      , (3, Pong)
      , (4, Message)
      , (5, Upgrade)
      , (6, Noop)
      ]
-- | A single protocol packet: its 'PacketType' together with its content.
-- Both fields are strict.
data Packet
  = Packet !PacketType !PacketContent
  deriving (Eq, Show)
-- | The body of a 'Packet': either raw bytes or text. Both fields are strict.
data PacketContent
  = BinaryPacket !BS.ByteString
    -- ^ Arbitrary binary data.
  | TextPacket !Text.Text
    -- ^ Textual data (UTF-8 on the wire, see 'encodePacket').
  deriving (Eq, Show)
-- | Parse one 'Packet', treating everything after the packet type marker as
-- the packet body (the body parser consumes all remaining input).
parsePacket :: Attoparsec.Parser Packet
parsePacket = parsePacket' remainingInput
  where
    remainingInput = Attoparsec.takeByteString
-- | Parse one 'Packet', reading the packet body with the supplied @body@
-- parser (which decides how many bytes belong to this packet).
--
-- Three encodings are tried, in this order:
--
-- 1. the character @b@, an ASCII digit @0@-@6@ for the type, then a
--    base64-encoded body, producing a 'BinaryPacket';
-- 2. a raw byte 0-6 for the type, then a raw body, producing a 'BinaryPacket';
-- 3. an ASCII digit @0@-@6@ for the type, then a UTF-8 body, producing a
--    'TextPacket'.
--
-- A body that is not valid base64 (case 1) or not valid UTF-8 (case 3) makes
-- the parser fail with a message; it never throws.
parsePacket' :: Attoparsec.Parser BS.ByteString -> Attoparsec.Parser Packet
parsePacket' body = parseBase64 <|> parseBinary <|> parseText
  where
    parseBase64 = do
      _ <- AttoparsecC8.char 'b'
      Packet <$> c8PacketType
             <*> (either fail (return . BinaryPacket) . Base64.decode =<< body)

    parseBinary =
      Packet <$> (packetTypeFromIndex =<< Attoparsec.satisfy (inRange (0, 6)))
             <*> (BinaryPacket <$> body)

    parseText =
      Packet <$> c8PacketType
             <*> (decodeText =<< body)

    -- The body comes from the client, so it may be malformed. Use the total
    -- decoder and report bad UTF-8 as a parse failure rather than letting
    -- 'Text.decodeUtf8' throw from pure code.
    decodeText bytes =
      either (fail . show) (return . TextPacket) (Text.decodeUtf8' bytes)

    -- Packet type written as an ASCII digit.
    c8PacketType =
      packetTypeFromIndex . digitToInt =<< AttoparsecC8.satisfy (inRange ('0', '6'))
-- | Serialise a single 'Packet'.
--
-- The first argument says whether the receiver accepts raw binary:
--
-- * text packets are always an ASCII type digit followed by UTF-8 text;
-- * binary packets are a raw type byte followed by the bytes when binary is
--   supported, and otherwise @b@, an ASCII type digit and the base64-encoded
--   bytes.
encodePacket
  :: Bool
  -> Packet
  -> Builder.Builder
encodePacket supportsBinary (Packet t content) =
  case content of
    TextPacket text ->
      asciiType <> Builder.byteString (Text.encodeUtf8 text)
    BinaryPacket raw
      | supportsBinary ->
          Builder.word8 (packetTypeToIndex t) <> Builder.byteString raw
      | otherwise ->
          Builder.char8 'b' <> asciiType <> Builder.byteString (Base64.encode raw)
  where
    -- Packet type rendered as a single ASCII digit.
    asciiType = Builder.char8 (intToDigit (packetTypeToIndex t))
-- | A batch of packets that is encoded or parsed as one unit
-- (see 'encodePayload' and 'parsePayload').
newtype Payload
  = Payload (V.Vector Packet)
  deriving (Eq, Show)
-- | Parse a 'Payload': one or more length-prefixed packets up to end of input.
--
-- Two framings are accepted; the binary one is tried first:
--
-- * binary framing (@goXHR2@): a byte 0 or 1, the length as a run of bytes
--   each holding one decimal digit (values 0-9), a 255 terminator byte, then
--   the packet;
-- * text framing (@goXHR@): the length as ASCII decimal digits, a colon, then
--   the packet.
--
-- In both framings the declared length covers the packet type marker as well
-- as the body, so the body parser is given @len - 1@ bytes.
--
-- NOTE(review): for the @b@-prefixed base64 form 'parsePacket'' consumes two
-- bytes before the body; confirm @len - 1@ is the intended body size there.
parsePayload :: Attoparsec.Parser Payload
parsePayload = Payload <$> (goXHR2 <|> goXHR)
  where
    goXHR = do
      len <- AttoparsecC8.decimal <* AttoparsecC8.char ':'
      -- One unit of the declared length is taken by the packet type.
      packet <- parsePacket' (Attoparsec.take (len - 1))
      (V.singleton packet <$ Attoparsec.endOfInput) <|> (V.cons packet <$> goXHR)

    goXHR2 = do
      _ <- Attoparsec.satisfy (`elem` [0, 1])
      len <- parseLength =<< Attoparsec.many1 (Attoparsec.satisfy (inRange (0, 9)))
      _ <- Attoparsec.word8 maxBound
      -- One unit of the declared length is taken by the packet type.
      packet <- parsePacket' (Attoparsec.take (len - 1))
      (V.singleton packet <$ Attoparsec.endOfInput) <|> (V.cons packet <$> goXHR2)

    -- Fold the digit bytes into a number, rejecting absurdly long digit runs.
    parseLength bytes = do
      guard (length bytes <= 319)
      return $ foldl' (\n x -> n * 10 + x) 0 $ map fromIntegral bytes
-- | Serialise a 'Payload' using the binary framing.
--
-- Each packet is written as: a flag byte (1 for a 'BinaryPacket', 0
-- otherwise), the byte length of the encoded packet as one byte per decimal
-- digit, a 255 terminator byte, and then the packet as produced by
-- 'encodePacket' with the same @supportsBinary@ flag.
encodePayload
  :: Bool
  -> Payload
  -> Builder.Builder
encodePayload supportsBinary (Payload packets) =
  V.foldr (\packet rest -> frame packet <> rest) mempty packets
  where
    frame packet =
      let encoded = encodePacket supportsBinary packet
          size = LBS.length (Builder.toLazyByteString encoded)
      in mconcat
           [ Builder.word8 (binaryFlag packet)
           , mconcat [ Builder.word8 (fromIntegral (digitToInt d)) | d <- show size ]
           , Builder.word8 maxBound
           , encoded
           ]

    binaryFlag (Packet _ (BinaryPacket _)) = 1
    binaryFlag _ = 0
-- | The transports a session can use. New sessions start on 'Polling'
-- (see 'freshSession') and may later switch to 'Websocket' (see 'upgrade').
data TransportType
  = Polling
  | Websocket
  deriving (Eq, Show)
-- | Transports are serialised as their lower-case name, as a JSON string.
instance Aeson.ToJSON TransportType where
  toJSON = Aeson.toJSON . transportName
    where
      transportName :: TransportType -> String
      transportName Polling = "polling"
      transportName Websocket = "websocket"
-- | Recognise a transport name: @polling@ or @websocket@ (exact match).
-- Anything else gives 'Nothing'.
parseTransportType :: Text.Text -> Maybe TransportType
parseTransportType t = lookup t knownTransports
  where
    knownTransports =
      [ ("polling", Polling)
      , ("websocket", Websocket)
      ]
-- | Identifier of a session. It is the key of the open-session map in
-- 'EngineIO' and is what 'handler' looks up from the @sid@ query parameter.
type SocketId = BS.ByteString
-- | The channel pair currently carrying a session's packets, plus the kind of
-- transport behind them.
data Transport = Transport
  { transIn :: STM.TChan Packet
    -- ^ Packets received from the client (written by 'handlePoll' and
    -- 'upgrade', read by the session loop in 'freshSession').
  , transOut :: STM.TChan Packet
    -- ^ Packets waiting to be delivered to the client.
  , transType :: !TransportType
    -- ^ Which transport is in use.
  }
-- | A connected session.
data Socket = Socket
  { socketId :: !SocketId
    -- ^ The session's identifier.
  , socketTransport :: STM.TVar Transport
    -- ^ The transport currently in use; replaced by 'upgrade'.
  , socketIncomingMessages :: STM.TChan PacketContent
    -- ^ Contents of received 'Message' packets, consumed via 'receive'.
  , socketOutgoingMessages :: STM.TChan PacketContent
    -- ^ Contents queued via 'send', to be wrapped in 'Message' packets.
  , socketRawIncomingBroadcast :: STM.TChan Packet
    -- ^ Broadcast channel receiving every incoming packet; see 'dupRawReader'.
  }
-- | Two sockets are equal exactly when their 'socketId's are equal.
instance Eq Socket where
  a == b = socketId a == socketId b
-- | Sockets are ordered by their 'socketId'.
instance Ord Socket where
  compare a b = compare (socketId a) (socketId b)
-- | Take the next incoming message content from the socket's incoming
-- channel. As an 'STM.readTChan', this retries until something is available.
receive :: Socket -> STM.STM PacketContent
receive = STM.readTChan . socketIncomingMessages
-- | Queue message content on the socket's outgoing channel. The session loop
-- in 'freshSession' wraps it in a 'Message' packet for the transport.
send :: Socket -> PacketContent -> STM.STM ()
send socket content = STM.writeTChan (socketOutgoingMessages socket) content
-- | The operations this module needs from the hosting web server, expressed
-- in the server's own monad @m@.
data ServerAPI m = ServerAPI
  { srvGetQueryParams :: m (HashMap.HashMap BS.ByteString [BS.ByteString])
    -- ^ Query parameters of the current request; each key maps to the list
    -- of values given for it.
  , srvTerminateWithResponse :: Int -> BS.ByteString -> Builder.Builder -> forall a . m a
    -- ^ Finish the request with a status code, a second header-like value
    -- (this module always passes a content type) and a response body.
  , srvParseRequestBody :: forall a. Attoparsec.Parser a -> m (Either String a)
    -- ^ Run a parser over the request body.
  , srvGetRequestMethod :: m BS.ByteString
    -- ^ The request method; 'handlePoll' compares it with @GET@ and @POST@.
  , srvRunWebSocket :: WebSockets.ServerApp -> m ()
    -- ^ Hand the request over to a WebSocket application.
  }
-- | Server-wide state shared by all requests.
data EngineIO = EngineIO
  { eioOpenSessions :: STM.TVar (HashMap.HashMap SocketId Socket)
    -- ^ All currently open sessions, keyed by their identifier.
  , eioRng :: MVar Random.GenIO
    -- ^ Random generator used by 'newSocketId'; the 'MVar' serialises access.
  }
-- | Create the shared server state: an empty session table and a random
-- generator obtained from 'Random.createSystemRandom'.
initialize :: IO EngineIO
initialize = do
  sessions <- STM.newTVarIO mempty
  rng <- newMVar =<< Random.createSystemRandom
  return EngineIO { eioOpenSessions = sessions, eioRng = rng }
-- | Read the table of currently open sessions.
getOpenSockets :: EngineIO -> STM.STM (HashMap.HashMap SocketId Socket)
getOpenSockets eio = STM.readTVar (eioOpenSessions eio)
-- | Protocol-level errors reported to the client by 'serveError'.
data EngineIOError
  = BadRequest
  | TransportUnknown
  | SessionIdUnknown
  deriving (Bounded, Enum, Eq, Show)
-- | The application attached to a socket by the user-supplied handler.
data SocketApp = SocketApp
  { saApp :: IO ()
    -- ^ The application's main action; 'freshSession' runs it asynchronously.
  , saOnDisconnect :: IO ()
    -- ^ Run by 'freshSession' after the session has been removed from the
    -- open-session table.
  }
-- | Entry point for every Engine.IO request.
--
-- The request is routed on its query parameters:
--
-- * @transport@ must be given exactly once and name a known transport,
--   otherwise the request fails with 'TransportUnknown' (this includes a
--   value that is not valid UTF-8);
-- * @sid@, when present, must be given exactly once and name an open
--   session, otherwise the request fails with 'SessionIdUnknown';
-- * @b64@ may be @1@ (the client does not accept binary), @0@ or absent
--   (binary accepted); anything else is a 'BadRequest'.
--
-- Without a @sid@ a new session is created with 'freshSession'. With a
-- @sid@ whose session is still polling, a polling request is served by
-- 'handlePoll' and a websocket request starts an 'upgrade'; a request for a
-- session that is no longer polling is a 'BadRequest'.
--
-- Arguments: the shared server state, the function that builds the
-- 'SocketApp' for each new socket, and the hosting server's 'ServerAPI'.
handler :: MonadIO m => EngineIO -> (Socket -> m SocketApp) -> ServerAPI m -> m ()
handler eio socketHandler api@ServerAPI{..} = do
  queryParams <- srvGetQueryParams
  eitherT (serveError api) return $ do
    reqTransport <- maybe (left TransportUnknown) return $ do
      [t] <- HashMap.lookup "transport" queryParams
      -- The parameter is client-controlled: decode totally so malformed
      -- UTF-8 becomes "transport unknown" instead of a thrown exception.
      either (const Nothing) parseTransportType (Text.decodeUtf8' t)

    socket <-
      for (HashMap.lookup "sid" queryParams) $ \sids -> do
        sid <- case sids of
                 [sid] -> return sid
                 _ -> left SessionIdUnknown

        mSocket <- liftIO (STM.atomically (HashMap.lookup sid <$> getOpenSockets eio))
        case mSocket of
          Nothing -> left SessionIdUnknown
          Just s -> return s

    supportsBinary <-
      case HashMap.lookup "b64" queryParams of
        Just ["1"] -> return False
        Just ["0"] -> return True
        Nothing -> return True
        _ -> left BadRequest

    case socket of
      Just s -> do
        transport <- liftIO $ STM.atomically $ STM.readTVar (socketTransport s)
        case transType transport of
          Polling
            | reqTransport == Polling -> lift (handlePoll api transport supportsBinary)
            | reqTransport == Websocket -> lift (upgrade api s)

          -- The session already left polling; no further HTTP requests are
          -- accepted for it.
          _ -> left BadRequest

      Nothing ->
        lift (freshSession eio socketHandler api supportsBinary)
-- | Create a new polling session and answer the request with its open packet.
--
-- Steps, in order:
--
-- 1. allocate a 'Socket' under a fresh, unused identifier and register it in
--    the open-session table;
-- 2. obtain the user's 'SocketApp' for that socket and start it;
-- 3. start a ping-timeout watcher and the packet-processing loop;
-- 4. start a supervisor that, once any of those three finishes, cancels the
--    rest, unregisters the session and runs 'saOnDisconnect';
-- 5. reply with a payload holding a single 'Open' packet that carries the
--    JSON-encoded 'OpenMessage'.
--
-- The 'Bool' is the @supportsBinary@ flag passed to 'encodePayload'.
freshSession
  :: MonadIO m
  => EngineIO
  -> (Socket -> m SocketApp)
  -> ServerAPI m
  -> Bool
  -> m ()
freshSession eio socketHandler api supportsBinary = do
  socket <- do
    -- Build everything except the identifier once, so that retrying the
    -- allocation only has to pick a new identifier.
    mkSocket <- liftIO $ do
      transport <- STM.newTVarIO =<< (Transport <$> STM.newTChanIO <*> STM.newTChanIO <*> pure Polling)
      incoming <- STM.newTChanIO
      outgoing <- STM.newTChanIO
      rawInBroadcast <- STM.newBroadcastTChanIO
      return (\sId -> Socket sId transport incoming outgoing rawInBroadcast)

    let
      -- One allocation attempt: draw an identifier and, in a single
      -- transaction, claim it unless it is already in the session table.
      tryAllocation = liftIO $ do
        sId <- newSocketId eio
        STM.atomically $ runMaybeT $ do
          openSessions <- lift (STM.readTVar (eioOpenSessions eio))
          guard (not (HashMap.member sId openSessions))
          let socket = mkSocket sId
          lift (STM.modifyTVar' (eioOpenSessions eio) (HashMap.insert sId socket))
          return socket

      -- Repeat an attempt, inserting one iteration step between tries.
      untilSuccess f = maybe (delay (untilSuccess f)) return =<< f

    -- 'cutoff 10' bounds the number of iteration steps; if no identifier was
    -- claimed within that bound the request ends with a 500 response.
    maybeSocket <- retract (cutoff 10 (untilSuccess tryAllocation))
    maybe (srvTerminateWithResponse api 500 "text/plain" "Session allocation failed")
          return maybeSocket

  app <- socketHandler socket
  userSpace <- liftIO $ Async.async (saApp app)

  -- The heartbeat thread finishes when the delay expires. The loop below
  -- pushes the delay back whenever a packet arrives from the client.
  -- NOTE(review): the delay is presumably in microseconds (60 s) -- confirm
  -- against the stm-delay documentation.
  pingTimeoutDelay <- liftIO $ STMDelay.newDelay (pingTimeout * 1000000)
  heartbeat <- liftIO $ Async.async $
    STM.atomically (STMDelay.waitDelay pingTimeoutDelay)

  -- Packet-processing loop. Each round performs one transaction that either
  -- handles one incoming transport packet (yielding 'Just' that packet) or
  -- forwards one outgoing message (yielding 'Nothing').
  brain <- liftIO $ Async.async $ fix $ \loop -> do
    mMessage <- STM.atomically $ do
      -- Re-read the transport every round so an upgrade is picked up.
      transport <- STM.readTVar (socketTransport socket)
      asum
        [ do req <- STM.readTChan (transIn transport)
             case req of
               -- Messages go to the application.
               Packet Message m ->
                 STM.writeTChan (socketIncomingMessages socket) m
               -- Pings are answered with a pong carrying the same content.
               Packet Ping m ->
                 STM.writeTChan (transOut transport) (Packet Pong m)
               _ ->
                 return ()
             -- Every incoming packet is also published to raw readers.
             STM.writeTChan (socketRawIncomingBroadcast socket) req
             return (Just req)
        , do STM.readTChan (socketOutgoingMessages socket)
               >>= STM.writeTChan (transOut transport) . Packet Message
             return Nothing
        ]

    -- Only incoming packets ('Just') reset the ping timeout.
    for_ mMessage (const (STMDelay.updateDelay pingTimeoutDelay (pingTimeout * 1000000)))

    -- A 'Close' packet ends the loop, which in turn ends the session.
    case mMessage of
      Just (Packet Close _) -> return ()
      _ -> loop

  -- Supervisor: wait for the first of the three threads to finish (normally
  -- or with an exception), cancel the others, then tear the session down.
  _ <- liftIO $ Async.async $ do
    _ <- Async.waitAnyCatchCancel [ userSpace, brain, heartbeat ]
    STM.atomically (STM.modifyTVar' (eioOpenSessions eio) (HashMap.delete (socketId socket)))
    saOnDisconnect app

  let openMessage = OpenMessage { omSocketId = socketId socket
                                , omUpgrades = [ Websocket ]
                                , omPingTimeout = pingTimeout * 1000
                                , omPingInterval = 25000
                                }
      payload = Payload $ V.singleton $
        Packet Open (TextPacket $ Text.decodeUtf8 $ LBS.toStrict $ Aeson.encode openMessage)

  writeBytes api (encodePayload supportsBinary payload)
  where
    -- Base value for the ping timeout; scaled by 1000000 for the delay above
    -- and by 1000 for the value advertised in the open message.
    pingTimeout = 60
-- | Switch an existing polling session to the WebSocket transport.
--
-- The WebSocket request is accepted, then a handshake is performed on the
-- new connection: the client must send a ping with the text @probe@ (which
-- is answered with a matching pong) and then an 'Upgrade' packet with an
-- empty body. When the handshake succeeds, the session's transport is
-- replaced by one that reuses the existing channels but is marked
-- 'Websocket', and packets are pumped in both directions until receiving
-- fails. A 'Close' packet is then injected so that the session loop stops.
--
-- If the handshake does not match, nothing is changed and the function
-- returns.
upgrade :: MonadIO m => ServerAPI m -> Socket -> m ()
upgrade ServerAPI{..} socket = srvRunWebSocket go
  where
    go pending = do
      conn <- WebSockets.acceptRequest $
        pending { WebSockets.pendingOnAccept = (const $ return ()) }

      -- 'Nothing' when the handshake deviates (a failed pattern match or the
      -- 'guard' below).
      mWsTransport <- runMaybeT $ do
        Packet Ping (TextPacket "probe") <- lift (receivePacket conn)
        lift (sendPacket conn (Packet Pong (TextPacket "probe")))

        -- The WebSocket transport keeps using the session's current channels.
        (wsIn, wsOut) <- liftIO $ STM.atomically $ do
          currentTransport <- STM.readTVar (socketTransport socket)
          return (transIn currentTransport, transOut currentTransport)

        -- After a short wait (100000 microseconds), queue a 'Noop' packet if
        -- the session is still polling.
        -- NOTE(review): presumably this releases a poll request that is
        -- still waiting for output -- confirm.
        check <-
          liftIO
            (Async.async
              (do threadDelay 100000
                  STM.atomically
                    (do t <- STM.readTVar (socketTransport socket)
                        when (transType t == Polling)
                          (STM.writeTChan (transOut t)
                            (Packet Noop (TextPacket Text.empty))))))

        Packet Upgrade body <- lift (receivePacket conn)
        guard (body == TextPacket Text.empty || body == BinaryPacket BS.empty)
        liftIO (Async.cancel check)
        return (Transport wsIn wsOut Websocket)

      for_ mWsTransport $ \wsTransport@Transport { transIn = wsIn, transOut = wsOut } -> do
        STM.atomically (STM.writeTVar (socketTransport socket) wsTransport)

        -- Forward queued outgoing packets to the WebSocket connection.
        reader <- Async.async $ forever $ do
          p <- STM.atomically (STM.readTChan wsOut)
          sendPacket conn p

        -- Feed received packets to the session until receiving throws; any
        -- exception is taken to mean the connection is gone.
        fix $ \loop -> do
          e <- try (receivePacket conn >>= STM.atomically . STM.writeTChan wsIn)
          case e of
            Left (SomeException _) ->
              return ()
            Right _ -> loop

        Async.cancel reader
        -- Tell the session loop that the connection has closed.
        STM.atomically (STM.writeTChan wsIn (Packet Close (TextPacket Text.empty)))

    -- Receive the next well-formed packet. Only text frames are parsed;
    -- malformed packets and other message kinds are logged and skipped.
    receivePacket conn = do
      msg <- WebSockets.receiveDataMessage conn
      case msg of
        WebSockets.Text bytes ->
          case Attoparsec.parseOnly parsePacket (LBS.toStrict bytes) of
            Left ex -> do
              putStrLn $ "Malformed packet received: " ++ show bytes ++ " (" ++ show ex ++ ")"
              receivePacket conn
            Right p -> return p
        other -> do
          putStrLn $ "Unknown WebSocket message: " ++ show other
          receivePacket conn

    -- Text packets are sent as a text frame (type digit followed by the
    -- text); binary packets as a binary frame in the raw binary encoding.
    sendPacket conn (Packet t (TextPacket text)) =
      WebSockets.sendTextData conn $
        Text.encodeUtf8 $
          Text.pack (pure $ intToDigit (packetTypeToIndex t)) <> text
    sendPacket conn p@(Packet _ (BinaryPacket _)) = do
      WebSockets.sendBinaryData conn (Builder.toLazyByteString (encodePacket True p))
-- | Serve one HTTP request of the polling transport.
--
-- * @GET@ waits for the next outgoing packet, giving up when the read
--   timeout registered with @45 * 1000000@ fires. It then replies with that
--   packet plus everything else already queued, or with a single 'Ping'
--   packet if nothing arrived in time.
-- * @POST@ parses the request body as a 'Payload' and queues its packets on
--   the transport's incoming channel. A parse failure is logged and answered
--   with status 400.
-- * Any other method is a 'BadRequest'.
--
-- The 'Bool' is the @supportsBinary@ flag passed to 'encodePayload'.
handlePoll :: MonadIO m => ServerAPI m -> Transport -> Bool -> m ()
handlePoll api@ServerAPI{..} transport supportsBinary =
  srvGetRequestMethod >>= dispatch
  where
    dispatch method
      | method == "GET" = servePoll
      | method == "POST" = acceptPost
      | otherwise = serveError api BadRequest

    outgoing = transOut transport

    servePoll = do
      timedOut <- liftIO (STM.registerDelay (45 * 1000000))
      packets <- liftIO $ do
        -- Block until either a packet is available or the timeout fires.
        firstPacket <- STM.atomically $
          (Just <$> STM.readTChan outgoing)
            <|> (Nothing <$ (STM.readTVar timedOut >>= STM.check))
        case firstPacket of
          Nothing ->
            return [ Packet Ping (BinaryPacket mempty) ]
          Just hd -> do
            -- Drain whatever else is queued without blocking.
            rest <- unfoldM (STM.atomically (STM.tryReadTChan outgoing))
            return (hd : rest)
      writeBytes api (encodePayload supportsBinary (Payload (V.fromList packets)))

    acceptPost =
      srvParseRequestBody parsePayload >>= either rejectBody enqueue

    rejectBody ex = do
      liftIO $ putStrLn $ "WARNING: Parse failure in Network.EngineIO.handlePoll: " ++ show ex
      srvTerminateWithResponse 400 "text/plain" "Empty request body"

    enqueue (Payload packets) =
      liftIO $ STM.atomically (V.mapM_ (STM.writeTChan (transIn transport)) packets)
-- | Finish the request with status 200, content type
-- @application/octet-stream@ and the given body.
writeBytes :: Monad m => ServerAPI m -> Builder.Builder -> m a
writeBytes api builder =
  srvTerminateWithResponse api 200 "application/octet-stream" builder
-- | Generate a candidate session identifier: 15 random bytes, each drawn
-- uniformly from 0 to 63, base64-encoded. The generator is used while
-- holding its 'MVar'.
--
-- Uniqueness is not checked here; 'freshSession' does that when it registers
-- the session.
newSocketId :: EngineIO -> IO SocketId
newSocketId eio = do
  rawBytes <- withMVar (eioRng eio) $ \gen ->
    replicateM 15 (Random.uniformR (0, 63) gen)
  return (Base64.encode (BS.pack rawBytes))
-- | The handshake data sent to the client in the 'Open' packet of a new
-- session (see 'freshSession').
data OpenMessage = OpenMessage
  { omSocketId :: !SocketId
    -- ^ Identifier of the new session.
  , omUpgrades :: [TransportType]
    -- ^ Transports the client may upgrade to.
  , omPingTimeout :: !Int
    -- ^ Advertised ping timeout.
  , omPingInterval :: !Int
    -- ^ Advertised ping interval.
  }
-- | Serialised as a JSON object with the keys @sid@, @upgrades@,
-- @pingTimeout@ and @pingInterval@. The identifier is decoded as UTF-8 text.
instance Aeson.ToJSON OpenMessage where
  toJSON msg =
    Aeson.object
      [ "sid" .= Text.decodeUtf8 (omSocketId msg)
      , "upgrades" .= omUpgrades msg
      , "pingTimeout" .= omPingTimeout msg
      , "pingInterval" .= omPingInterval msg
      ]
-- | Finish the request with status 400 and a JSON body of the form
-- @{"code": ..., "message": ...}@ describing the given error.
serveError :: Monad m => ServerAPI m -> EngineIOError -> m a
serveError api e =
  srvTerminateWithResponse api 400 "application/json" (Builder.lazyByteString body)
  where
    body = Aeson.encode (Aeson.object [ "code" .= code, "message" .= message ])

    (code, message) = describe e

    -- Numeric code and human-readable text for each error.
    describe :: EngineIOError -> (Int, Text.Text)
    describe TransportUnknown = (0, "Transport unknown")
    describe SessionIdUnknown = (1, "Session ID unknown")
    describe BadRequest = (3, "Bad request")
-- | Create an independent reader for the socket's raw incoming packets.
--
-- The socket's broadcast channel is duplicated, and the returned transaction
-- reads the next packet from that duplicate.
dupRawReader :: Socket -> IO (STM.STM Packet)
dupRawReader s =
  STM.readTChan <$> STM.atomically (STM.dupTChan (socketRawIncomingBroadcast s))