module Network.WebSockets.Monad
( WebSocketsOptions (..)
, defaultWebSocketsOptions
, WebSockets (..)
, runWebSockets
, runWebSocketsWith
, runWebSocketsHandshake
, runWebSocketsWithHandshake
, runWebSocketsWith'
, receiveWith
, sendWith
, send
, Sink
, sendSink
, getSink
, getOptions
, getProtocol
, getVersion
, throwWsError
, catchWsError
, spawnPingThread
) where
import Control.Applicative (Applicative, (<$>))
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (newMVar, withMVar)
import Control.Exception (Exception (..), SomeException, throw, throwIO)
import Control.Monad (forever)

import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Control.Monad.State (StateT, evalStateT, get, put, runStateT)
import Control.Monad.Trans (MonadIO, lift, liftIO)
import Blaze.ByteString.Builder (Builder)
import Blaze.ByteString.Builder.Enumerator (builderToByteString)
import Data.ByteString (ByteString)
import Data.Enumerator (Enumerator, Iteratee, ($$), (>>==))
import qualified Data.Attoparsec.Enumerator as AE
import qualified Data.Enumerator as E

import Network.WebSockets.Demultiplex (DemultiplexState, emptyDemultiplexState)
import Network.WebSockets.Handshake
import Network.WebSockets.Handshake.Http
import Network.WebSockets.Handshake.ShyIterParser
import Network.WebSockets.Mask
import Network.WebSockets.Protocol
import Network.WebSockets.Types as T
-- | User-tunable settings for a WebSockets session.
data WebSocketsOptions = WebSocketsOptions
    { onPong :: IO ()
      -- ^ Callback named for pong handling. It is not invoked anywhere in
      -- this module; presumably the receive/demultiplex path runs it when a
      -- pong frame arrives (TODO confirm).
    }
-- | Default settings: the 'onPong' callback does nothing.
defaultWebSocketsOptions :: WebSocketsOptions
defaultWebSocketsOptions = WebSocketsOptions {onPong = return ()}
-- | Read-only environment available to every 'WebSockets' action.
data WebSocketsEnv p = WebSocketsEnv
    { options     :: WebSocketsOptions
      -- ^ Settings supplied by the caller of the run functions.
    , sendBuilder :: Builder -> IO ()
      -- ^ Writer for encoded output; 'runWebSocketsWith'' installs one that
      -- serialises callers through an 'MVar'.
    , protocol    :: p
      -- ^ The protocol value the session runs with.
    }
-- | The application monad for a WebSockets connection.
--
-- It layers a reader over 'WebSocketsEnv' and a state over the
-- 'DemultiplexState' on top of an 'Iteratee' consuming the raw input stream.
-- The instances below are newtype-derived from that stack.
newtype WebSockets p a = WebSockets
    { unWebSockets ::
        ReaderT (WebSocketsEnv p)
                (StateT DemultiplexState (Iteratee ByteString IO))
                a
    }
    deriving (Functor, Applicative, Monad, MonadIO)
-- | 'runWebSocketsWithHandshake' using 'defaultWebSocketsOptions'.
runWebSocketsHandshake :: Protocol p
                       => (Request -> WebSockets p a)
                       -> Iteratee ByteString IO ()
                       -> Iteratee ByteString IO a
runWebSocketsHandshake goWs outIter =
    runWebSocketsWithHandshake defaultWebSocketsOptions goWs outIter
-- | Read the HTTP part of the handshake request from the input, then
-- continue exactly like 'runWebSocketsWith'.
--
-- The request is decoded with 'decodeRequest' via 'receiveIteratee'. That
-- fails with 'ConnectionClosed' if the input is already at EOF.
runWebSocketsWithHandshake :: Protocol p
                           => WebSocketsOptions
                           -> (Request -> WebSockets p a)
                           -> Iteratee ByteString IO ()
                           -> Iteratee ByteString IO a
runWebSocketsWithHandshake opts goWs outIter =
    receiveIteratee decodeRequest >>= \httpPart ->
        runWebSocketsWith opts httpPart goWs outIter
-- | 'runWebSocketsWith' using 'defaultWebSocketsOptions'.
runWebSockets :: Protocol p
              => RequestHttpPart
              -> (Request -> WebSockets p a)
              -> Iteratee ByteString IO ()
              -> Iteratee ByteString IO a
runWebSockets httpReq goWs outIter =
    runWebSocketsWith defaultWebSocketsOptions httpReq goWs outIter
-- | Finish the handshake for an already-parsed HTTP request part, then run
-- the application.
--
-- If 'tryFinishRequest' yields an error, the protocol's 'responseError' is
-- written to @outIter@ and the error is then thrown in the iteratee.
-- Otherwise the application is started with the resulting request and
-- protocol value via 'runWebSocketsWith''.
runWebSocketsWith :: forall p a. Protocol p
                  => WebSocketsOptions
                  -> RequestHttpPart
                  -> (Request -> WebSockets p a)
                  -> Iteratee ByteString IO ()
                  -> Iteratee ByteString IO a
runWebSocketsWith opts httpReq goWs outIter =
    receiveIterateeShy (tryFinishRequest httpReq) >>= either reject accept
  where
    -- Report the handshake failure to the client, then abort.
    reject err = do
        sendIteratee encodeResponse (responseError witness err) outIter
        E.throwError err

    -- Handshake done: hand over to the application.
    accept (req, negotiated) =
        runWebSocketsWith' opts negotiated (goWs req) outIter

    -- Only the type matters here: it selects which protocol's
    -- 'responseError' is used. The value itself must never be forced.
    witness :: p
    witness = undefined
-- | Run a 'WebSockets' application for a given protocol value.
--
-- Builds the session's 'WebSocketsEnv' and runs @ws@ against the input
-- stream, starting from 'emptyDemultiplexState'. Every 'Builder' written
-- through 'sendBuilder' is handed to 'builderSender' while a private 'MVar'
-- is held. Concurrent writers, such as the thread forked by
-- 'spawnPingThread', are therefore serialised.
runWebSocketsWith' :: Protocol p
                   => WebSocketsOptions
                   -> p
                   -> WebSockets p a
                   -> Iteratee ByteString IO ()
                   -> Iteratee ByteString IO a
runWebSocketsWith' opts proto ws outIter = do
    lock <- liftIO (newMVar ())
    let writeLocked b = withMVar lock (\_ -> builderSender outIter b)
        env = WebSocketsEnv
            { options     = opts
            , sendBuilder = writeLocked
            , protocol    = proto
            }
    evalStateT (runReaderT (unWebSockets ws) env) emptyDemultiplexState
-- | Fork a background thread that sends a ping with payload @"Hi"@ every
-- @i@ seconds. The first ping is sent only after the first full delay.
--
-- The loop never exits on its own; the thread ends only if 'sendSink'
-- throws. The 'ThreadId' is discarded, so callers cannot kill the thread.
spawnPingThread :: BinaryProtocol p => Int -> WebSockets p ()
spawnPingThread i = do
    sink <- getSink
    let micros  = i * 1000 * 1000          -- threadDelay takes microseconds
        pingMsg = ping ("Hi" :: ByteString)
    _ <- liftIO . forkIO . forever $
        threadDelay micros >> sendSink sink pingMsg
    return ()
-- | Decode one value from the connection with a custom decoder.
-- This runs 'receiveIteratee' inside the 'WebSockets' monad.
receiveWith :: Decoder p a -> WebSockets p a
receiveWith decoder = liftIteratee (receiveIteratee decoder)
-- | Parse one value from the input stream.
--
-- If the stream is already at EOF, this fails with 'ConnectionClosed'
-- without running the parser. Parser failures go through
-- 'wrappingParseError'.
receiveIteratee :: Decoder p a -> Iteratee ByteString IO a
receiveIteratee parser = E.isEOF >>= go
  where
    go True  = E.throwError ConnectionClosed
    go False = wrappingParseError $ AE.iterParser parser
-- | Like 'receiveIteratee', but runs the parser with 'shyIterParser' and
-- performs no EOF check first. Presumably "shy" means it does not consume
-- input beyond what the parser needs; see the ShyIterParser module
-- (TODO confirm).
receiveIterateeShy :: Decoder p a -> Iteratee ByteString IO a
receiveIterateeShy = wrappingParseError . shyIterParser
-- | Re-raise errors from an iteratee so that any exception that
-- 'ParseError' can wrap is rethrown as 'ParseError'. All other exceptions
-- propagate unchanged. The wrapped type is presumably attoparsec's parse
-- error (TODO confirm).
wrappingParseError :: (Monad m) => Iteratee a m b -> Iteratee a m b
wrappingParseError iter = iter `E.catchError` rethrow
  where
    rethrow e = E.throwError $ case fromException e of
        Just parseErr -> toException (ParseError parseErr)
        Nothing       -> e
-- | Encode a value and write it straight to @outIter@ from the iteratee
-- layer, without a 'WebSocketsEnv' and without taking the session's send
-- lock. 'runWebSocketsWith' uses it for the handshake error response.
sendIteratee :: Encoder p a -> a
             -> Iteratee ByteString IO ()
             -> Iteratee ByteString IO ()
sendIteratee enc resp outIter =
    liftIO (mkSend (builderSender outIter) enc resp)
-- | Encode a value with a custom encoder and write it through the session's
-- 'sendBuilder'. Unlike 'send', this bypasses the message-to-frame mapping.
sendWith :: Encoder p a -> a -> WebSockets p ()
sendWith encoder x = WebSockets $
    ask >>= \env -> liftIO (mkSend (sendBuilder env) encoder x)
-- | Send a message on the connection, encoded with the session's protocol
-- (via the 'Sink' obtained from 'getSink').
send :: Protocol p => T.Message p -> WebSockets p ()
send msg = do
    sink <- getSink
    liftIO $ sendSink sink msg
-- | A plain-'IO' handle for sending messages on a connection. It can be used
-- outside the 'WebSockets' monad, for example from a forked thread.
newtype Sink p = Sink
    { unSink :: Message p -> IO ()
    }
-- | Send a message through a 'Sink'.
sendSink :: Sink p -> Message p -> IO ()
sendSink (Sink write) = write
-- | Capture the current protocol and writer as a 'Sink'.
--
-- Each message is turned into exactly one 'Frame' whose first field is
-- 'True' (presumably the FIN flag, i.e. unfragmented — TODO confirm). It is
-- encoded with 'encodeFrame' under a fresh random mask (see 'mkSend').
getSink :: Protocol p => WebSockets p (Sink p)
getSink = do
    proto  <- getProtocol
    writer <- WebSockets (fmap sendBuilder ask)
    return . Sink . mkSend writer $ \msk -> encodeFrame proto msk . toFrame
  where
    toFrame (ControlMessage (Close pl)) = Frame True CloseFrame pl
    toFrame (ControlMessage (Ping pl))  = Frame True PingFrame pl
    toFrame (ControlMessage (Pong pl))  = Frame True PongFrame pl
    toFrame (DataMessage (Text pl))     = Frame True TextFrame pl
    toFrame (DataMessage (Binary pl))   = Frame True BinaryFrame pl
-- | Draw a new mask with 'randomMask' for this call, encode @x@ with it, and
-- hand the resulting 'Builder' to the writer.
mkSend :: (Builder -> IO ()) -> Encoder p a -> a -> IO ()
mkSend send' encoder x =
    randomMask >>= \msk -> send' (encoder msk x)
-- | An enumerator that feeds exactly one chunk, containing just @c@, to an
-- iteratee that is waiting for input, and then returns the resulting step.
-- Any step that is not waiting for input is passed back unchanged.
singleton :: Monad m => a -> Enumerator a m b
singleton c (E.Continue k) = k (E.Chunks [c]) >>== E.returnI
singleton _ step           = E.returnI step
-- | Push a single 'Builder' into a fresh run of the output iteratee.
--
-- The builder is converted with 'builderToByteString', and @outIter@ is run
-- to completion with 'E.run' on every call; its result value is discarded.
-- If the run fails, the error is re-raised with 'throwIO' (lifted via the
-- 'MonadIO' constraint). That ties the exception to this point in the action
-- sequence rather than relying on when a pure 'throw' thunk gets forced.
builderSender :: MonadIO m => Iteratee ByteString m b -> Builder -> m ()
builderSender outIter x = do
    result <- E.run $ singleton x $$ builderToByteString $$ outIter
    case result of
        Left err -> liftIO $ throwIO err
        Right _  -> return ()
-- | The 'WebSocketsOptions' the session was started with.
getOptions :: WebSockets p WebSocketsOptions
getOptions = WebSockets $ options <$> ask
-- | The protocol value the session is running with.
getProtocol :: WebSockets p p
getProtocol = WebSockets (fmap protocol ask)
-- | The session protocol's 'version' string.
getVersion :: Protocol p => WebSockets p String
getVersion = fmap version getProtocol
-- | Raise an exception as an error in the underlying iteratee. Errors raised
-- this way can be intercepted with 'catchWsError'.
throwWsError :: (Exception e) => e -> WebSockets p a
throwWsError e = liftIteratee (E.throwError e)
-- | Run @act@; if it raises an iteratee error, run @handler@ on that error.
--
-- Both branches run with the current environment and demultiplexer state.
-- The demultiplexer state that the completed branch ends with is written
-- back afterwards. Previously both branches used 'evalStateT', so any state
-- progress made inside @act@ or the handler was silently lost. If @act@
-- fails, the handler starts from the state as it was before @act@ ran. This
-- matches the usual @liftCatch@ semantics for 'StateT'.
catchWsError :: WebSockets p a
             -> (SomeException -> WebSockets p a)
             -> WebSockets p a
catchWsError act handler = WebSockets $ do
    env <- ask
    s0  <- get
    -- Unwrap an action down to the base iteratee, keeping its final state.
    let peel w = runStateT (runReaderT (unWebSockets w) env) s0
    (result, s1) <- lift . lift $ peel act `E.catchError` (peel . handler)
    put s1
    return result
-- | Embed a base-layer iteratee action into 'WebSockets' by lifting it
-- through the state and reader layers.
liftIteratee :: Iteratee ByteString IO a -> WebSockets p a
liftIteratee it = WebSockets (lift (lift it))