module Network.WebSockets (
shakeHands,
getRequest, putResponse,
getFrame, putFrame,
reqHost, reqPath, reqOrigin, reqLocation,
Request()) where
import System.IO (Handle, hPutChar, hFlush, hGetChar, hPutStr, hGetLine)
import Data.Binary (encode)
import Data.Int (Int32)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Digest.Pure.MD5 (md5)
import Data.Char (isDigit, chr, ord)
import Data.List (isPrefixOf, isSuffixOf)
import qualified Control.Exception as E
import qualified Data.Map as M
-- | Encode a 'String' as a strict 'B.ByteString', one byte per 'Char'.
-- Each code point is narrowed with 'fromIntegral', so characters above
-- @'\\255'@ keep only their low 8 bits.
fromString :: String -> B.ByteString
fromString = B.pack . map (fromIntegral . ord)
-- | Decode a strict 'B.ByteString' to a 'String', mapping every byte to the
-- 'Char' with the same code (0-255).
toString :: B.ByteString -> String
toString = map (chr . fromIntegral) . B.unpack
-- | Perform the server side of the opening handshake on the handle.
--
-- Reads the client's request with 'getRequest'. When it parsed and
-- validated, the response is written with 'putResponse'. The parse result
-- is returned unchanged in both cases.
shakeHands :: Handle -> IO (Either String Request)
shakeHands h = do
  request <- getRequest h
  case request of
    Right req -> putResponse h req
    -- Nothing is sent for a rejected request; the caller gets the Left.
    Left _ -> return ()
  return request
-- | A client handshake request that passed validation.
data Request = Request
  { reqHost :: String
    -- ^ Value of the @Host@ header.
  , reqPath :: String
    -- ^ Request target taken from the @GET@ line.
  , reqOrigin :: String
    -- ^ Value of the @Origin@ header.
  , reqKey1 :: String
    -- ^ Raw value of the @Sec-WebSocket-Key1@ header.
  , reqKey2 :: String
    -- ^ Raw value of the @Sec-WebSocket-Key2@ header.
  , reqToken :: String
    -- ^ Bytes read after the blank line ending the headers, one 'Char'
    -- per byte.
  }
-- | The @ws://@ URL the client asked for: scheme, host, then path.
reqLocation :: Request -> String
reqLocation r = concat ["ws://", reqHost r, reqPath r]
-- | Summarises a request as the location it asked for and its origin.
instance Show Request where
  show r = concat ["requested ", reqLocation r, " from ", reqOrigin r]
-- | Fields collected while parsing a handshake.
--
-- Header names are stored without their trailing colon. Besides real
-- headers, the map holds two synthetic entries: "Path" (from the GET line)
-- and "Token" (the bytes read after the header block).
type RawRequest = M.Map String String
-- | Read and parse a client opening handshake from the handle.
--
-- The request has three parts:
--
-- * The first line must start with @GET @ and end with @ HTTP/1.1\\r@.
--
-- * CRLF-terminated header lines follow. Only @Host@, @Connection@,
--   @Sec-WebSocket-Key1@, @Sec-WebSocket-Key2@, @Upgrade@ and @Origin@ are
--   accepted, each at most once.
--
-- * After the blank line, 8 raw bytes are read as the token. The collected
--   fields are then checked by 'validateRequest'.
--
-- Malformed input yields @Left@ with a description. So does any I\/O error
-- raised while reading, including EOF on the very first line.
getRequest :: Handle -> IO (Either String Request)
getRequest h = readRequest `E.catch` onIOError
  where
    -- Only I/O failures become Left. Catching SomeException here would
    -- also swallow asynchronous exceptions (e.g. from System.Timeout).
    onIOError :: E.IOException -> IO (Either String Request)
    onIOError e = return . Left $ show e

    readRequest :: IO (Either String Request)
    readRequest = do
      first <- toString `fmap` B.hGetLine h
      case words first of
        (_ : path : _)
          | "GET " `isPrefixOf` first && " HTTP/1.1\r" `isSuffixOf` first ->
              step (M.singleton "Path" path)
        _ ->
          return . Left $ "First line is not a valid GET request: " ++ show first

    step :: RawRequest -> IO (Either String Request)
    step req = do
      line <- toString `fmap` B.hGetLine h
      if null line
        then return . Left $ "Got empty line in header: " ++ show line
        else case stripCR line of
          -- Rejecting instead of blindly dropping the last character keeps
          -- LF-only lines from losing real data.
          Nothing ->
            return . Left $ "Header line not terminated by CRLF: " ++ show line
          -- Blank line: end of headers, the 8-byte token follows.
          Just "" -> do
            bytes <- (map (chr . fromIntegral) . BL.unpack) `fmap` BL.hGet h 8
            return . validateRequest $ M.insert "Token" bytes req
          Just content -> addHeader req line content

    -- Record one "Key: value" header line and continue parsing.
    addHeader :: RawRequest -> String -> String -> IO (Either String Request)
    addHeader req line content
      | key `notElem` allowedKeys =
          return . Left $ "Unrecognized header key in line: " ++ show line
      | name `M.member` req =
          return . Left $ "Duplicate key: " ++ show key
      -- drop 1 (not tail) so a header with no value cannot leave a lazy
      -- exception in the map that would fire outside this catch.
      | otherwise = step $ M.insert name (drop 1 val) req
      where
        (key, val) = break (== ' ') content
        name = takeWhile (/= ':') key

    allowedKeys :: [String]
    allowedKeys =
      [ "Host:", "Connection:", "Sec-WebSocket-Key1:"
      , "Sec-WebSocket-Key2:", "Upgrade:", "Origin:"
      ]

    -- Remove a trailing carriage return; Nothing when there is none.
    stripCR :: String -> Maybe String
    stripCR s = case reverse s of
      '\r' : rest -> Just (reverse rest)
      _ -> Nothing
-- | Check a parsed header map and convert it to a 'Request'.
--
-- Requirements:
--
-- * All of Host, Path, Origin, both Sec-WebSocket-Key fields and Token
--   must be present.
--
-- * Each key must contain at least one space and at least one digit.
--   Without these, 'divNumBySpaces' would divide by zero or fail to 'read'
--   when the response token is computed.
--
-- * Each key's digit value must be an exact multiple of its space count,
--   as the hixie-76 handshake requires.
--
-- * The token must be exactly 8 bytes; a short read at EOF yields fewer.
validateRequest :: RawRequest -> Either String Request
validateRequest req
  | lacksHeaderKeys = Left $ "Bad request, keys missing: " ++ show req
  | any faultyKey [1, 2 :: Int] = Left $ "Faulty Sec-WebSocketKey: " ++ show req
  | length (req M.! "Token") /= 8 =
      Left $ "Bad request, token is not 8 bytes: " ++ show req
  | otherwise = Right $ fromRaw req
  where
    lacksHeaderKeys =
      any (`M.notMember` req)
        [ "Host", "Path", "Origin", "Sec-WebSocket-Key1"
        , "Sec-WebSocket-Key2", "Token"
        ]
    faultyKey n =
      let key = req M.! ("Sec-WebSocket-Key" ++ show n)
          digits = filter isDigit key
          spaces = toInteger (length (filter (== ' ') key))
      in spaces == 0
           || null digits
           -- Safe read: digits is non-empty and contains only '0'..'9'.
           || (read digits :: Integer) `mod` spaces /= 0
-- | Build a 'Request' from a header map.
--
-- Partial: every looked-up key must be present, which 'validateRequest'
-- checks before calling this.
fromRaw :: RawRequest -> Request
fromRaw r =
  Request
    (field "Host")
    (field "Path")
    (field "Origin")
    (field "Sec-WebSocket-Key1")
    (field "Sec-WebSocket-Key2")
    (field "Token")
  where
    field = (r M.!)
-- | Send the handshake response for a validated request and flush it.
putResponse :: Handle -> Request -> IO ()
putResponse h req = do
  B.hPutStr h (createResponse req)
  -- Without the flush, a block-buffered handle can keep the response
  -- queued while the client waits for it before sending any frames.
  hFlush h
-- | Build the server's handshake response bytes.
--
-- The response contains, in order:
--
-- * the 101 status line;
--
-- * the Upgrade and Connection headers;
--
-- * Sec-WebSocket-Origin echoing the request's origin;
--
-- * Sec-WebSocket-Location, the URL from 'reqLocation';
--
-- * a fixed Sec-WebSocket-Protocol;
--
-- * a blank line, followed by the challenge answer from 'createToken'.
--
-- NOTE(review): the protocol is always "sample", regardless of what the
-- client asked for; confirm this is intended.
createResponse :: Request -> B.ByteString
createResponse req = B.append (fromString header) (createToken req)
  where
    -- Every header line ends in CRLF; the extra CRLF ends the header block.
    header =
      concatMap (++ "\r\n")
        [ "HTTP/1.1 101 Web Socket Protocol Handshake"
        , "Upgrade: WebSocket"
        , "Connection: Upgrade"
        , "Sec-WebSocket-Origin: " ++ reqOrigin req
        , "Sec-WebSocket-Location: " ++ reqLocation req
        , "Sec-WebSocket-Protocol: sample"
        ]
        ++ "\r\n"
-- | Compute the handshake challenge answer.
--
-- The answer is the 'encode'd MD5 digest of three byte strings:
--
-- * the 'encode'd 'Int32' for Key1 (from 'divNumBySpaces');
--
-- * the same for Key2;
--
-- * the token bytes.
--
-- The result is returned as a strict 'B.ByteString'.
createToken :: Request -> B.ByteString
createToken req = B.concat . BL.toChunks $ encode digest
  where
    keyBytes = encode . divNumBySpaces
    tokenBytes = BL.pack . map (fromIntegral . ord) $ reqToken req
    digest = md5 $ BL.concat [keyBytes (reqKey1 req), keyBytes (reqKey2 req), tokenBytes]
-- | Key number of a Sec-WebSocket-Key value.
--
-- Concatenate all ASCII digits into one number and integer-divide it by
-- the count of spaces. The quotient is narrowed to 'Int32', wrapping if it
-- does not fit.
--
-- Partial: the string must contain at least one digit (otherwise 'read'
-- fails) and at least one space (otherwise division by zero).
divNumBySpaces :: String -> Int32
divNumBySpaces str = fromInteger (keyNumber `div` spaceCount)
  where
    keyNumber = read (filter isDigit str) :: Integer
    spaceCount = toInteger (length (filter (== ' ') str))
-- | Write a payload as one frame: a 0x00 byte, the payload, then a 0xFF
-- terminator; the handle is flushed afterwards.
--
-- The payload should not itself contain 0xFF, because 'getFrame' stops
-- reading at the first 0xFF it sees.
putFrame :: Handle -> B.ByteString -> IO ()
putFrame h bs = do
  let frame = B.concat [B.singleton 0, bs, B.singleton 255]
  B.hPutStr h frame
  hFlush h
-- | Read one frame from the handle and return its payload.
--
-- The first (frame-type) byte is consumed without being inspected. Bytes
-- are then collected up to the next 0xFF, which is not included. If EOF is
-- reached before the first byte or before the terminator, the result is
-- an empty 'B.ByteString'; any partial payload is discarded.
--
-- NOTE(review): a frame-type byte other than 0x00 (e.g. 0xFF for a closing
-- frame) gets no special handling; confirm whether callers need that.
getFrame :: Handle -> IO B.ByteString
getFrame h = do
  frameType <- B.hGet h 1
  if B.null frameType
    then return frameType
    else collect []
  where
    -- Gather payload bytes in reverse and pack once at the end. The old
    -- per-byte B.snoc copied the whole buffer on every byte (O(n^2)).
    collect acc = do
      chunk <- B.hGet h 1
      case B.uncons chunk of
        Nothing -> return B.empty
        Just (255, _) -> return (B.pack (reverse acc))
        Just (byte, _) -> collect (byte : acc)