{-# LANGUAGE InstanceSigs #-}

module Simplex.Messaging.Transport.WebSockets (WS (..)) where

import qualified Control.Exception as E
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Network.Socket (Socket)
import Network.WebSockets
import Network.WebSockets.Stream (Stream)
import qualified Network.WebSockets.Stream as S
import Simplex.Messaging.Transport (TProxy, Transport (..), TransportError (..), trimCR)

-- | A WebSocket transport connection.
--
-- Both 'getServerConnection' and 'getClientConnection' build it from a
-- socket-backed 'Stream' and the WebSocket 'Connection' negotiated over
-- that same stream.
data WS = WS
  { -- | Underlying byte stream; this is what 'closeConnection' closes.
    wsStream :: Stream,
    -- | WebSocket connection used for all reads and writes.
    wsConnection :: Connection
  }

-- | WebSocket options shared by the server and client handshakes.
--
-- Starts from 'defaultConnectionOptions' and then:
--
-- * turns compression off ('NoCompression');
-- * limits incoming frame payloads to 8192 and incoming message data to
--   65536 (byte limits, per the websockets 'ConnectionOptions' docs).
websocketsOpts :: ConnectionOptions
websocketsOpts =
  defaultConnectionOptions
    { connectionCompressionOptions = NoCompression,
      connectionFramePayloadSizeLimit = SizeLimit maxFramePayload,
      connectionMessageDataSizeLimit = SizeLimit maxMessageData
    }
  where
    -- A message may span several frames, so its cap exceeds the per-frame cap.
    maxFramePayload = 8192
    maxMessageData = 65536

-- | 'Transport' over WebSockets.
--
-- Both sides wrap the raw 'Socket' in a websockets 'Stream' and use
-- 'websocketsOpts':
--
-- * the server accepts the client's handshake;
-- * the client sends a handshake for path @"/"@.
--
-- Writes go out as binary WebSocket messages. Every read consumes exactly
-- one whole WebSocket message.
instance Transport WS where
  transportName :: TProxy WS -> String
  transportName _ = "WebSockets"

  -- Server side: wrap the socket in a stream, then accept the pending
  -- client handshake on it.
  getServerConnection :: Socket -> IO WS
  getServerConnection sock = do
    s <- S.makeSocketStream sock
    WS s <$> acceptClientRequest s
    where
      acceptClientRequest :: Stream -> IO Connection
      acceptClientRequest s =
        makePendingConnectionFromStream s websocketsOpts >>= acceptRequest

  -- Client side: wrap the socket in a stream, then perform the client
  -- handshake with an empty host, path "/" and no extra headers.
  getClientConnection :: Socket -> IO WS
  getClientConnection sock = do
    s <- S.makeSocketStream sock
    WS s <$> sendClientRequest s
    where
      sendClientRequest :: Stream -> IO Connection
      sendClientRequest s = newClientConnection s "" "/" websocketsOpts []

  -- Close the underlying stream directly.
  -- This does not go through 'sendClose'.
  closeConnection :: WS -> IO ()
  closeConnection = S.close . wsStream

  -- Read one whole WebSocket message. It must be exactly @n@ bytes long;
  -- any other length throws 'TEBadBlock'.
  cGet :: WS -> Int -> IO ByteString
  cGet c n = do
    s <- receiveData (wsConnection c)
    if B.length s == n then pure s else E.throwIO TEBadBlock

  -- Send the bytes as one binary WebSocket message.
  cPut :: WS -> ByteString -> IO ()
  cPut = sendBinaryData . wsConnection

  -- Read one whole WebSocket message that must end with '\n'.
  -- Empty messages, or messages whose last byte is not '\n', throw
  -- 'TEBadBlock'.
  --
  -- The '\n' is removed first and 'trimCR' is applied afterwards. Order
  -- matters: on the raw message the last byte is '\n', so 'trimCR' would
  -- have nothing to trim and a "...\r\n" line would keep its '\r'.
  -- NOTE(review): assumes 'trimCR' drops one trailing '\r' (as its name
  -- suggests); confirm against Simplex.Messaging.Transport.
  getLn :: WS -> IO ByteString
  getLn c = do
    s <- receiveData (wsConnection c)
    case B.unsnoc s of
      Just (line, '\n') -> pure $ trimCR line
      _ -> E.throwIO TEBadBlock