module Network.WebSockets (
  shakeHands, 
  getRequest, putResponse, 
  getFrame, putFrame, 
  reqHost, reqPath, reqOrigin, reqLocation,
  Request()) where

import System.IO (Handle, hPutChar, hFlush, hGetChar, hPutStr, hGetLine)
import Data.Binary (encode)
import Data.Int (Int32)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Digest.Pure.MD5 (md5)
import Data.Char (isDigit, chr, ord)
import Data.List (isPrefixOf, isSuffixOf)
import qualified Control.Exception as E
import qualified Data.Map as M


-- Quick and dirty String<->ByteString conversion: every Char becomes one byte
-- taken from its code point, and every byte becomes the Char with that code.
-- Characters whose code point does not fit in a byte do not round-trip.
fromString :: String -> B.ByteString
fromString str = B.pack [fromIntegral (ord c) | c <- str]

toString :: B.ByteString -> String
toString bytes = [chr (fromIntegral w) | w <- B.unpack bytes]


{- | Perform the whole handshake, accepting any well-formed request.

The request is read with 'getRequest'; if that succeeds, the accepting response is sent with 'putResponse'. Nothing is sent when reading the request fails. Use this function if you don't care who you're connected to, as long as that someone speaks the WebSocket protocol, and call it immediately after establishing the connection.

The result is exactly what 'getRequest' returned: a String describing the error, or the 'Request'. Since the handshake has already been answered by the time the 'Request' is returned, it is only useful for logging.

If you wish not to blindly accept requests but to filter them according to their contents, use the 'getRequest' and 'putResponse' functions yourself. -}
shakeHands :: Handle -> IO (Either String Request)
shakeHands h = do
  outcome <- getRequest h
  -- answer only a successfully parsed request; an error is passed through
  either (const (return ())) (putResponse h) outcome
  return outcome
                    

{- | The details of a client's request. Read them with the 'reqHost', 'reqPath', 'reqOrigin' and 'reqLocation' functions. -}
data Request = Request
  { reqHost :: String -- ^ Returns the requested host.
  , reqPath :: String -- ^ Returns the requested path.
  , reqOrigin :: String -- ^ Returns the origin of the request.
  -- The two security keys and the token sent by the client. They are used
  -- to compute the response token and are not in the module's export list.
  , reqKey1 :: String
  , reqKey2 :: String
  , reqToken :: String
  }

{- | Returns the requested location: the @ws:\/\/@ scheme followed by the requested host and path. Equal to @(\\r -> \"ws:\/\/\" ++ reqHost r ++ reqPath r)@. -}
reqLocation :: Request -> String
reqLocation r = concat ["ws://", reqHost r, reqPath r]

-- A short, log-friendly description: the requested location and the origin.
-- The security keys and the token are not shown.
instance Show Request where
  show r = concat ["requested ", reqLocation r, " from ", reqOrigin r]

{- The client's request while it is still being read, as a map from names to values.

Header values are stored under the header name without the trailing colon, e.g. \"Origin\", \"Upgrade\" and \"Sec-WebSocket-Key2\". Two entries do not come from header lines: the requested path is stored under \"Path\" and the eight-byte token that follows the headers is stored under \"Token\". -}
type RawRequest = M.Map String String


{- | Reads the client's opening handshake and returns either a 'Request' based on its contents, or a String in case of an error.

The first line must start with @GET@, end with @HTTP\/1.1@ and a carriage return, and carries the requested path as its second word. Header lines are then read until a line holding only a carriage return is found; the eight bytes after that line are the token. Reading stops with an error on an unrecognized header key, a duplicate header key, a recognized header key without a value, a completely empty line, or a token shorter than eight bytes.

Exceptions thrown while reading the headers and the token are caught and returned as the error String. An exception thrown while reading the very first line is not caught here. -}
getRequest :: Handle -> IO (Either String Request)
getRequest h = do
  -- the first line should be a "GET :path: HTTP/1.1"
  first <- toString `fmap` B.hGetLine h
  case words first of
    (_ : path : _)
      | "GET " `isPrefixOf` first && " HTTP/1.1\r" `isSuffixOf` first ->
          step (M.singleton "Path" path) `E.catch` exceptionToLeft
    _ -> return . Left $ "First line is not a valid GET request: " ++ show first

  where
    -- turns any exception raised while reading the headers into an error
    exceptionToLeft :: E.SomeException -> IO (Either String Request)
    exceptionToLeft = return . Left . show

    -- the only header keys that are accepted, as they appear on the wire
    knownKeys :: [String]
    knownKeys = ["Host:", "Connection:", "Sec-WebSocket-Key1:",
                 "Sec-WebSocket-Key2:", "Upgrade:", "Origin:"]

    -- reads and stores all of the header values, stopping when
    -- it encounters an unrecognized header key, duplicate header keys,
    -- a header without a value, or an empty line followed by the
    -- eight-byte token.
    step :: RawRequest -> IO (Either String Request)
    step req = do
      line <- toString `fmap` B.hGetLine h
      if null line
        then return . Left $ "Got empty line in header: " ++ show line
        -- 'init' drops the last character of the line, which is expected
        -- to be the carriage return; the line is known to be non-empty here
        else case break (== ' ') (init line) of
          ("", "") -> do
            -- we skip this empty line and read the next 8 bytes, the token
            bytes <- toString `fmap` B.hGet h 8
            return $
              if length bytes == 8
                then validateRequest $ M.insert "Token" bytes req
                -- fewer bytes means the input ended in the middle of the token
                else Left $ "Expected an eight-byte token, got: " ++ show bytes
          (key, val)
            -- can we recognize the header key? raise an error if not.
            | key `notElem` knownKeys ->
                return . Left $ "Unrecognized header key in line: " ++ show line
            -- also, raise an error if duplicate header keys are read.
            | M.member name req ->
                return . Left $ "Duplicate key: " ++ show key
            | otherwise ->
                -- 'val' is either empty or starts with the separating space.
                -- Matching on it here, instead of storing a lazy 'tail val',
                -- keeps a malformed line from raising an exception later,
                -- outside of the handler, when the stored value is first used.
                case val of
                  (_ : value) -> step $ M.insert name value req
                  [] -> return . Left $ "Missing value for header key: " ++ show key
            where
              -- the key without its trailing colon; only evaluated for
              -- recognized keys, which are never empty
              name = init key
                   
    

{- Checks if a given raw request is valid or not, and converts it to a final 'Request' if it is.

A valid request contains all the data necessary to create a response: \"Host\", \"Path\", \"Origin\", \"Sec-WebSocket-Key1\", \"Sec-WebSocket-Key2\" and \"Token\". Both security keys must also contain at least one space and at least one digit, so that calculating the response token can neither divide by zero nor fail to read a number. Returns an error String if the request is not valid. -}
validateRequest :: RawRequest -> Either String Request
validateRequest req =
  case fromRaw of
    Nothing -> Left $ "Bad request, keys missing: " ++ show req
    Just request
      | any faultyKey [reqKey1 request, reqKey2 request] ->
          Left $ "Faulty Sec-WebSocketKey: " ++ show req
      | otherwise -> Right request
  where
    -- a key is unusable when it has no space (the number hidden in it would
    -- be divided by zero) or no digit (there would be no number to read)
    faultyKey :: String -> Bool
    faultyKey key = ' ' `notElem` key || not (any isDigit key)

    -- Nothing as soon as any of the required entries is missing
    fromRaw :: Maybe Request
    fromRaw = do
      host <- M.lookup "Host" req
      path <- M.lookup "Path" req
      origin <- M.lookup "Origin" req
      key1 <- M.lookup "Sec-WebSocket-Key1" req
      key2 <- M.lookup "Sec-WebSocket-Key2" req
      token <- M.lookup "Token" req
      return Request { reqHost = host,
                       reqPath = path,
                       reqOrigin = origin,
                       reqKey1 = key1,
                       reqKey2 = key2,
                       reqToken = token
                     }

            
{- | Sends an accepting response based on the given 'Request', thus accepting and ending the handshake.

The handle is flushed after the response is written, so the response is not left sitting in the handle's buffer while the client is still waiting for it. -}
putResponse :: Handle -> Request -> IO ()
putResponse h req = do
  B.hPutStr h (createResponse req)
  -- flush explicitly, as 'putFrame' does for frames
  hFlush h


{- Builds the accepting response for the given 'Request': the response header lines, an empty line, and then the response token from 'createToken'. The origin, host and path in the header are taken from the request. -}
createResponse :: Request -> B.ByteString
createResponse req = fromString headerText `B.append` createToken req
  where
    -- every header line ends in CR LF; the final lone CR LF ends the header
    headerText = concat
      [ "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
      , "Upgrade: WebSocket\r\n"
      , "Connection: Upgrade\r\n"
      , "Sec-WebSocket-Origin: ", reqOrigin req, "\r\n"
      , "Sec-WebSocket-Location: ws://", reqHost req, reqPath req, "\r\n"
      , "Sec-WebSocket-Protocol: sample\r\n"
      , "\r\n"
      ]

                                              
{- Constructs the response token, as defined by the protocol: the MD5 hash of the two numbers derived from the security keys (see 'divNumBySpaces'), each serialised with 'encode', followed by the bytes of the request's token. The hash itself is serialised with 'encode' and returned as a strict ByteString. -}
createToken :: Request -> B.ByteString
createToken req = B.concat . BL.toChunks $ encode digest
  where
    -- serialised number hidden in a security key
    keyBytes = encode . divNumBySpaces
    -- the client's token, one byte per character
    tokenBytes = BL.pack [fromIntegral (ord c) | c <- reqToken req]
    digest = md5 $ BL.concat [keyBytes (reqKey1 req), keyBytes (reqKey2 req), tokenBytes]


{- Takes the number formed by the digits of the string, in order, and divides it (integer division) by the number of spaces in the string, as defined in the protocol. The string must contain at least one digit and at least one space; this is not checked here, since the request was verified to be valid beforehand. -}
divNumBySpaces :: String -> Int32
divNumBySpaces str = fromIntegral (hiddenNumber `div` spaceCount)
  where
    -- all digits of the string read as a single number
    hiddenNumber :: Integer
    hiddenNumber = read (filter isDigit str)

    spaceCount :: Integer
    spaceCount = fromIntegral (length (filter (== ' ') str))


{- | Send a strict ByteString as one frame: a 0 byte, the given bytes, then a 255 byte. The handle is flushed afterwards. Call this function only after having performed the handshake. -}
putFrame :: Handle -> B.ByteString -> IO ()
putFrame h bs = B.hPutStr h frame >> hFlush h
  where
    -- the payload wrapped in its start (0) and end (255) markers
    frame = B.concat [B.singleton 0, bs, B.singleton 255]
  

{- | Receive a strict ByteString. Call this function only after having performed the handshake. This function will block until an entire frame is read. If the writing end of the handle is closed, the function returns an empty ByteString.

The first byte of the frame is read and discarded without being checked; the bytes after it, up to but not including the first 255 byte, are the result. If the input ends before the 255 byte is seen, the bytes read so far are dropped and an empty ByteString is returned. -}
getFrame :: Handle -> IO B.ByteString
getFrame h = do
  first <- B.hGet h 1 -- assume this is 0
  if B.null first
     then return B.empty
     else readUntil255 []
  where
    -- The bytes read so far are kept in reverse order in a list and packed
    -- once at the end; appending to a ByteString one byte at a time would
    -- copy the whole buffer for every byte read.
    readUntil255 acc = do
      b <- B.hGet h 1
      case B.unpack b of
        []        -> return B.empty                  -- input ended mid-frame
        (255 : _) -> return (B.pack (reverse acc))   -- end-of-frame marker
        (n : _)   -> readUntil255 (n : acc)