module Network.Stack.Server where
import Control.Concurrent.Async
import Control.Concurrent.MVar
import qualified Control.Exception as E
import Control.Monad
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BS
import qualified Data.ByteString.Builder.Extra as BS
import qualified Data.ByteString.Lazy as BSL
import Data.Int
import Data.Typeable
import qualified Data.X509 as X509
import qualified Network.TLS as TLS
import qualified Network.WebSockets as WS
import qualified Network.WebSockets.Stream as WS
import qualified System.Socket as S
import qualified System.Socket.Type.Stream as S
-- | Type-level tag for a TLS layer stacked on top of the transport @a@.
-- It has no values; it only selects the 'ServerStack' and
-- 'StreamServerStack' instances declared for @TLS a@.
data TLS a

-- | Type-level tag for a WebSocket layer stacked on top of the transport
-- @a@. Like 'TLS', it has no values and only selects instances.
data WebSocket a
-- | A server-side transport layer.
--
-- Every layer supplies its own server handle, configuration, exception
-- type, per-connection handle and per-connection metadata as associated
-- data families, together with two callback-scoped operations: one that
-- runs a server and one that accepts a single connection.
class Typeable a => ServerStack a where
  -- | A running server of this layer.
  data Server a

  -- | Everything needed to start a 'Server'.
  data ServerConfig a

  -- | Errors raised by this layer.
  data ServerException a

  -- | One accepted client connection.
  data ServerConnection a

  -- | Metadata about an accepted connection (for the socket layer this is
  -- the peer address returned by @accept@).
  data ServerConnectionInfo a

  -- | Bring a server up from its configuration and run the callback with it.
  withServer
    :: ServerConfig a
    -> (Server a -> IO b)
    -> IO b

  -- | Accept one connection and run the handler for it in a fresh 'Async'.
  withConnection
    :: Server a
    -> (ServerConnection a -> ServerConnectionInfo a -> IO b)
    -> IO (Async b)
-- | A 'ServerStack' whose connections carry an ordered byte stream.
--
-- The defaults are defined in terms of each other:
--
-- * 'sendStream' and 'sendStreamLazy' are mutually defined;
-- * 'receiveStream' and 'receiveStreamLazy' are mutually defined.
--
-- An instance must therefore implement at least one method of each pair.
-- Otherwise the defaults call each other forever. The @MINIMAL@ pragma
-- makes GHC warn about such incomplete instances instead of letting them
-- compile silently.
class ServerStack a => StreamServerStack a where
  {-# MINIMAL (sendStream | sendStreamLazy), (receiveStream | receiveStreamLazy) #-}

  -- | Send a strict chunk and return the byte count reported as sent.
  -- Default: delegate to 'sendStreamLazy'.
  sendStream :: ServerConnection a -> BS.ByteString -> IO Int
  sendStream server bs = fromIntegral <$> sendStreamLazy server (BSL.fromStrict bs)

  -- | Send a lazy bytestring. Default: send it chunk by chunk with
  -- 'sendStream', summing the per-chunk counts with a strict accumulator.
  sendStreamLazy :: ServerConnection a -> BSL.ByteString -> IO Int64
  sendStreamLazy server = foldM step 0 . BSL.toChunks
    where
      step sent bs = do
        sent' <- sendStream server bs
        pure $! sent + fromIntegral sent'

  -- | Render a 'BS.Builder' into chunks of @chunksize@ bytes (used as both
  -- the first and the subsequent buffer size) and send them with
  -- 'sendStreamLazy'.
  sendStreamBuilder :: ServerConnection a -> Int -> BS.Builder -> IO Int64
  sendStreamBuilder server chunksize = sendStreamLazy server
    . BS.toLazyByteStringWith (BS.untrimmedStrategy chunksize chunksize) mempty

  -- | Receive a strict chunk; the 'Int' is the requested size, which
  -- instances may treat as an upper bound or ignore.
  -- Default: strictify 'receiveStreamLazy'.
  receiveStream :: ServerConnection a -> Int -> IO BS.ByteString
  receiveStream server i = BSL.toStrict <$> receiveStreamLazy server i

  -- | Lazy variant of 'receiveStream'. Default: wrap 'receiveStream'.
  receiveStreamLazy :: ServerConnection a -> Int -> IO BSL.ByteString
  receiveStreamLazy server i = BSL.fromStrict <$> receiveStream server i
-- | A 'ServerStack' whose connections exchange discrete, typed messages
-- rather than a raw byte stream.
class ServerStack a => MessageServerStack a where
  -- | Message type sent from clients to the server.
  type ClientMessage a
  -- | Message type sent from the server to clients.
  type ServerMessage a

  -- Outgoing direction.

  -- | Send one message.
  -- NOTE(review): the 'Int64' result presumably counts bytes written —
  -- confirm against instances.
  sendMessage :: ServerConnection a -> ServerMessage a -> IO Int64

  -- | Send every message of a 'Foldable' container.
  sendMessages :: Foldable t => ServerConnection a -> t (ServerMessage a) -> IO Int64

  -- Incoming direction.

  -- | Receive one message.
  -- NOTE(review): the 'Int64' argument is presumably a maximum message
  -- size — confirm against instances.
  receiveMessage :: ServerConnection a -> Int64 -> IO (ClientMessage a)

  -- | Feed incoming messages to the callback.
  -- NOTE(review): the 'Bool' result presumably means "keep consuming" —
  -- confirm against instances.
  consumeMessages :: ServerConnection a -> Int64 -> (ClientMessage a -> IO Bool) -> IO ()
-- | Raw "System.Socket" stream sockets as a byte-stream transport.
--
-- Each method is a thin wrapper over the matching socket call. Every call
-- passes 'S.msgNoSignal' (POSIX @MSG_NOSIGNAL@), which on sends suppresses
-- @SIGPIPE@ when the peer has gone away. 'receiveStreamLazy' is left to the
-- class default, which wraps 'receiveStream'.
instance (Typeable f, Typeable p, S.Family f, S.Protocol p) => StreamServerStack (S.Socket f S.Stream p) where
  sendStream (SocketServerConnection sock) chunk =
    S.sendAll sock chunk S.msgNoSignal

  sendStreamLazy (SocketServerConnection sock) payload =
    S.sendAllLazy sock payload S.msgNoSignal

  sendStreamBuilder (SocketServerConnection sock) bufferSize builder =
    S.sendAllBuilder sock bufferSize builder S.msgNoSignal

  -- NOTE(review): MSG_NOSIGNAL is specified as a send flag; passing it to
  -- receive looks unnecessary — confirm it is ignored on all targets.
  receiveStream (SocketServerConnection sock) maxBytes =
    S.receive sock maxBytes S.msgNoSignal
-- | Byte-stream access to an established TLS session.
--
-- Only 'sendStreamLazy' and 'receiveStream' are defined here; the other
-- methods come from the class defaults. The requested size passed to
-- 'receiveStream' is ignored: it returns whatever 'TLS.recvData' yields.
instance (StreamServerStack a) => StreamServerStack (TLS a) where
  sendStreamLazy conn payload = do
    TLS.sendData (tlsContext conn) payload
    pure (BSL.length payload)

  receiveStream conn _requested =
    TLS.recvData (tlsContext conn)
-- | Byte-stream access over an accepted WebSocket connection.
--
-- Every send goes out as one binary WebSocket message via
-- 'WS.sendBinaryData', and the reported count is the payload length.
-- 'receiveStreamLazy' returns the payload of the next data message from
-- 'WS.receiveData' and ignores the requested size. 'receiveStream' comes
-- from the class default, which strictifies that result.
instance (StreamServerStack a) => StreamServerStack (WebSocket a) where
  sendStream conn payload = do
    WS.sendBinaryData (wsConnection conn) payload
    pure (BS.length payload)

  sendStreamLazy conn payload = do
    WS.sendBinaryData (wsConnection conn) payload
    pure (BSL.length payload)

  receiveStreamLazy conn _requested =
    WS.receiveData (wsConnection conn)
-- | Plain "System.Socket" sockets as the bottom layer of a server stack.
instance (S.Family f, S.Type t, S.Protocol p, Typeable f, Typeable t, Typeable p) => ServerStack (S.Socket f t p) where
  -- | The listening socket together with the configuration it was built from.
  data Server (S.Socket f t p) = SocketServer
    { socketServer       :: !(S.Socket f t p)
    , socketServerConfig :: !(ServerConfig (S.Socket f t p))
    }
  -- | Address to bind and the backlog passed to @listen@.
  data ServerConfig (S.Socket f t p) = SocketServerConfig
    { socketServerConfigBindAddress     :: !(S.SocketAddress f)
    , socketServerConfigListenQueueSize :: Int
    }
  data ServerException (S.Socket f t p) = SocketServerException !S.SocketException
  -- | The socket returned by @accept@.
  data ServerConnection (S.Socket f t p) = SocketServerConnection !(S.Socket f t p)
  -- | The peer address returned by @accept@.
  data ServerConnectionInfo (S.Socket f t p) = SocketServerConnectionInfo !(S.SocketAddress f)

  -- Create the socket, enable SO_REUSEADDR, bind and listen, then run the
  -- callback. 'E.bracket' guarantees the listening socket is closed afterwards.
  withServer c handle = E.bracket
    (SocketServer <$> S.socket <*> pure c)
    (S.close . socketServer) $ \server -> do
      S.setSocketOption (socketServer server) (S.ReuseAddress True)
      S.bind (socketServer server) (socketServerConfigBindAddress $ socketServerConfig server)
      S.listen (socketServer server) (socketServerConfigListenQueueSize $ socketServerConfig server)
      handle server

  -- Accept one connection and hand it to a new 'Async' that owns it.
  --
  -- 'async' runs its action in the caller's masking state. If the caller is
  -- unmasked, an asynchronous exception (e.g. 'cancel') could reach the
  -- child before 'E.finally' is installed, and the accepted socket would
  -- leak. Forking under 'E.mask' and calling 'restore' only around the
  -- handler (the 'forkFinally' pattern) installs the cleanup first, while
  -- the handler still runs in the caller's original masking state.
  -- 'E.bracketOnError' closes the socket if forking itself fails.
  withConnection server handle =
    E.bracketOnError (S.accept (socketServer server)) (S.close . fst) $ \(connection, addr) ->
      E.mask $ \restore ->
        async $
          restore (handle (SocketServerConnection connection) (SocketServerConnectionInfo addr))
            `E.finally` S.close connection
-- | TLS layered over any stream transport @a@.
--
-- The TLS record layer is driven through a 'TLS.Backend' that writes with
-- 'sendStream' and reads exactly the requested number of bytes with
-- 'receiveStream' on the underlying transport connection.
instance (StreamServerStack a, Typeable a) => ServerStack (TLS a) where
  data Server (TLS a) = TlsServer
    { tlsTransportServer :: Server a
    , tlsServerConfig    :: ServerConfig (TLS a)
    }
  -- | Transport configuration plus the TLS server parameters.
  data ServerConfig (TLS a) = TlsServerConfig
    { tlsTransportConfig :: ServerConfig a
    , tlsServerParams    :: TLS.ServerParams
    }
  -- | Thrown when the transport yields an empty read in the middle of a
  -- TLS backend receive.
  data ServerException (TLS a) =
      TlsServerEndOfStreamException
    deriving (Eq, Ord, Show)
  data ServerConnection (TLS a) = TlsServerConnection
    { tlsTransportConnection :: ServerConnection a
    , tlsContext             :: TLS.Context
    }
  -- | Transport info plus the client certificate chain, if the client
  -- presented one during the handshake.
  data ServerConnectionInfo (TLS a) = TlsServerConnectionInfo
    { tlsTransportServerConnectionInfo :: ServerConnectionInfo a
    , tlsCertificateChain              :: Maybe X509.CertificateChain
    }

  withServer config handle =
    withServer (tlsTransportConfig config) $ \server ->
      handle (TlsServer server config)

  -- Perform the TLS handshake on each transport connection, run the
  -- handler, then send close_notify with 'TLS.bye'.
  withConnection server handle =
    withConnection (tlsTransportServer server) $ \connection info -> do
      let backend = TLS.Backend
            { TLS.backendFlush = pure ()
            , TLS.backendClose = pure ()
            , TLS.backendSend  = void . sendStream connection
            , TLS.backendRecv  = \bytes -> receiveExactly connection bytes mempty
            }
      mvar <- newEmptyMVar
      let srvParams = tlsServerParams (tlsServerConfig server)
          srvParams' = srvParams
            { TLS.serverHooks = (TLS.serverHooks srvParams)
                { TLS.onClientCertificate = \certChain -> do
                    -- tryPutMVar rather than putMVar: if the hook were ever
                    -- invoked a second time on this context, putMVar on the
                    -- already-full MVar would block the handshake forever.
                    -- The first chain is kept.
                    void (tryPutMVar mvar certChain)
                    pure TLS.CertificateUsageAccept
                }
            }
      context <- TLS.contextNew backend srvParams'
      TLS.handshake context
      -- Empty if the client never presented a certificate.
      certificateChain <- tryTakeMVar mvar
      x <- handle
        (TlsServerConnection connection context)
        (TlsServerConnectionInfo info certificateChain)
      TLS.bye context
      pure x
    where
      -- Read exactly @bytes@ bytes from the transport, looping over short
      -- reads. An empty transport read means the peer closed the stream, so
      -- it is reported as 'TlsServerEndOfStreamException'. A request for
      -- zero (or fewer) bytes is satisfied immediately instead of being
      -- mistaken for end-of-stream.
      receiveExactly connection bytes accum
        | bytes <= 0 = pure accum
        | otherwise = do
            bs <- receiveStream connection bytes
            when (BS.null bs) $
              E.throwIO (TlsServerEndOfStreamException :: ServerException (TLS a))
            let accum' = accum `BS.append` bs
            if BS.length bs < bytes
              then receiveExactly connection (bytes - BS.length bs) $! accum'
              else pure $! accum'
-- | WebSockets layered over any stream transport @a@.
--
-- Each transport connection is wrapped in a 'WS.Stream', the opening
-- handshake is accepted, the handler runs, and a close frame is sent.
instance (StreamServerStack a) => ServerStack (WebSocket a) where
  data Server (WebSocket a) = WebSocketServer
    { wsTransportServer :: Server a
    }
  data ServerConfig (WebSocket a) = WebSocketServerConfig
    { wsTransportConfig :: ServerConfig a
    }
  data ServerException (WebSocket a) = WebSocketServerException
  data ServerConnection (WebSocket a) = WebSocketServerConnection
    { wsTransportConnection :: ServerConnection a
    , wsConnection          :: WS.Connection
    }
  -- | Transport info plus the client's opening-handshake request head.
  data ServerConnectionInfo (WebSocket a) = WebSocketServerConnectionInfo
    { wsTransportServerConnectionInfo :: ServerConnectionInfo a
    , wsRequestHead                   :: WS.RequestHead
    }

  withServer config handle =
    withServer (wsTransportConfig config) $ \server ->
      handle (WebSocketServer server)

  withConnection server handle =
    withConnection (wsTransportServer server) $ \connection info -> do
      -- An empty transport read is reported to websockets as end of input.
      let readSocket = (\bs -> if BS.null bs then Nothing else Just bs) <$> receiveStream connection 4096
          writeSocket Nothing   = pure ()
          writeSocket (Just bs) = void (sendStream connection (BSL.toStrict bs))
      stream <- WS.makeStream readSocket writeSocket
      -- defaultConnectionOptions has the same no-op on-pong action as the
      -- previous positional 'WS.ConnectionOptions' value, and it does not
      -- depend on the record's exact field layout.
      pendingConnection <- WS.makePendingConnectionFromStream stream WS.defaultConnectionOptions
      let request = WS.pendingRequest pendingConnection
          -- RFC 6455 §4.1: a client must fail the connection if the server
          -- selects a subprotocol the client did not offer, so "mqtt" is
          -- only selected when the client asked for it.
          subprotocol
            | "mqtt" `elem` WS.getRequestSubprotocols request = Just "mqtt"
            | otherwise = Nothing
      acceptedConnection <- WS.acceptRequestWith pendingConnection (WS.AcceptRequest subprotocol [])
      x <- handle
        (WebSocketServerConnection connection acceptedConnection)
        (WebSocketServerConnectionInfo info request)
      WS.sendClose acceptedConnection ("Thank you for flying Haskell." :: BS.ByteString)
      pure x
-- 'Show' for connection info is derived standalone so that each instance
-- can state its context explicitly. The socket layer needs 'Show' for its
-- address type; each wrapping layer needs 'Show' for the info of the
-- transport beneath it.
deriving instance Show (S.SocketAddress f) => Show (ServerConnectionInfo (S.Socket f t p))
deriving instance Show (ServerConnectionInfo a) => Show (ServerConnectionInfo (TLS a))
deriving instance Show (ServerConnectionInfo a) => Show (ServerConnectionInfo (WebSocket a))

-- | Makes the TLS layer's 'ServerException' throwable and catchable through
-- "Control.Exception". The 'Typeable' constraint is needed to identify the
-- concrete @TLS a@ at runtime.
instance Typeable a => E.Exception (ServerException (TLS a))