module Network.Wai.Handler.WebSockets
( intercept
, interceptWith
) where
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.Char (toLower)
import Network.Wai (Request, requestHeaders, rawPathInfo, requestHeaders)
import qualified Data.ByteString.Char8 as S
import qualified Data.Conduit as C
import qualified Data.Enumerator as E
import qualified Data.Enumerator.List as EL
import qualified Network.WebSockets as WS
import Network.Wai.Handler.Warp (Connection (..))
-- | Build a connection interceptor for a WebSockets application using
-- 'WS.defaultWebSocketsOptions'.
--
-- This is 'interceptWith' with the default options filled in; see that
-- function for when the result is 'Nothing' and when it is a handler.
intercept
    :: WS.Protocol p
    => (WS.Request -> WS.WebSockets p ())
    -- ^ WebSockets application to run for accepted requests
    -> Request
    -- ^ incoming WAI request to inspect
    -> Maybe (C.BufferedSource IO ByteString -> Connection -> C.ResourceT IO ())
intercept app req = interceptWith WS.defaultWebSocketsOptions app req
-- | Decide whether a WAI request should be taken over by a WebSockets
-- application, using the supplied options.
--
-- The request's @upgrade@ header is looked up; if it is present and its
-- value, lower-cased, is exactly @websocket@, the result is 'Just' a handler
-- that runs the application over the given source and connection. In every
-- other case (header missing or carrying a different value) the result is
-- 'Nothing'.
--
-- NOTE(review): the header lookup relies on the 'Eq' instance of the header
-- name type; confirm that it compares case-insensitively.
interceptWith
    :: WS.Protocol p
    => WS.WebSocketsOptions
    -- ^ options passed through to the WebSockets runner
    -> (WS.Request -> WS.WebSockets p ())
    -- ^ WebSockets application to run for accepted requests
    -> Request
    -- ^ incoming WAI request to inspect
    -> Maybe (C.BufferedSource IO ByteString -> Connection -> C.ResourceT IO ())
interceptWith opts app req
    | wantsWebSocket = Just (runWebSockets opts httpPart app)
    | otherwise      = Nothing
  where
    headers = requestHeaders req

    -- True only when an upgrade header exists and names "websocket"
    -- (compared after lower-casing the header value).
    wantsWebSocket =
        maybe False (\value -> S.map toLower value == "websocket")
              (lookup "upgrade" headers)

    -- The part of the HTTP request handed to the WebSockets library:
    -- the raw path plus all request headers.
    httpPart = WS.RequestHttpPart (rawPathInfo req) headers
-- | Run a WebSockets application over a conduit source and a Warp
-- connection.
--
-- The application is turned into an iteratee via 'WS.runWebSocketsWith',
-- with outgoing data written through 'connSendAll' on the connection. That
-- iteratee is then used as the state of a conduit sink: every chunk pulled
-- from the source is fed to the current iteratee, and the resulting step
-- becomes the next state.
--
-- NOTE(review): the sink keeps processing regardless of which step the
-- iteratee returns, and nothing is done when the source closes — confirm
-- this is the intended lifetime behaviour.
runWebSockets
    :: WS.Protocol p
    => WS.WebSocketsOptions
    -> WS.RequestHttpPart
    -> (WS.Request -> WS.WebSockets p ())
    -> C.BufferedSource IO ByteString
    -> Connection
    -> C.ResourceT IO ()
runWebSockets opts req app source conn = do
    -- Start the application so the sink has an initial iteratee state.
    initial <- liftIO (E.runIteratee (WS.runWebSocketsWith opts req app outgoing))
    source C.$$ C.sinkState (E.returnI initial) feed finish
  where
    -- Consumes every outgoing element by sending it over the connection.
    outgoing = EL.mapM_ (connSendAll conn)

    -- Closing the sink performs no action and ignores the final state.
    finish _ = return ()

    -- Feed one incoming chunk to the iteratee and carry on with the
    -- resulting step as the new sink state.
    feed iter chunk = do
        next <- liftIO (E.runIteratee (E.enumList 1 [chunk] E.$$ iter))
        return (C.StateProcessing (E.returnI next))