{-# LANGUAGE NoImplicitPrelude #-}

-- |
--    Wuss is a library that lets you easily create secure WebSocket clients over
--    the WSS protocol. It is a small addition to
--    <https://hackage.haskell.org/package/websockets the websockets package>
--    and is adapted from existing solutions by
--    <https://gist.github.com/jaspervdj/7198388 @jaspervdj>,
--    <https://gist.github.com/mpickering/f1b7ba3190a4bb5884f3 @mpickering>, and
--    <https://gist.github.com/elfenlaid/7b5c28065e67e4cf0767 @elfenlaid>.
--
--    == Example
--
--    > import Wuss
--    >
--    > import Control.Concurrent (forkIO)
--    > import Control.Monad (forever, unless, void)
--    > import Data.Text (Text, pack)
--    > import Network.WebSockets (ClientApp, receiveData, sendClose, sendTextData)
--    >
--    > main :: IO ()
--    > main = runSecureClient "echo.websocket.org" 443 "/" ws
--    >
--    > ws :: ClientApp ()
--    > ws connection = do
--    >     putStrLn "Connected!"
--    >
--    >     void . forkIO . forever $ do
--    >         message <- receiveData connection
--    >         print (message :: Text)
--    >
--    >     let loop = do
--    >             line <- getLine
--    >             unless (null line) $ do
--    >                 sendTextData connection (pack line)
--    >                 loop
--    >     loop
--    >
--    >     sendClose connection (pack "Bye!")
--
--    == Retry
--
--    Note that the connection itself, or any individual message, can fail and
--    may need to be retried. This can be handled with something like
--    <https://hackage.haskell.org/package/retry the retry package>.
--    See <https://github.com/tfausak/wuss/issues/18#issuecomment-990921703 this comment> for an example.
module Wuss
  ( runSecureClient,
    runSecureClientWith,
    Config (..),
    defaultConfig,
    runSecureClientWithConfig,
  )
where

import qualified Control.Applicative as Applicative
import qualified Control.Exception as Exception
import qualified Control.Monad.Catch as Catch
import qualified Control.Monad.IO.Class as MonadIO
import qualified Data.Bool as Bool
import qualified Data.ByteString as StrictBytes
import qualified Data.ByteString.Lazy as LazyBytes
import qualified Data.Maybe as Maybe
import qualified Data.String as String
import qualified Network.Connection as Connection
import qualified Network.Socket as Socket
import qualified Network.WebSockets as WebSockets
import qualified Network.WebSockets.Stream as Stream
import qualified System.IO as IO
import qualified System.IO.Error as IO.Error
import Prelude (($), (.))

-- |
--    A secure replacement for 'Network.WebSockets.runClient'.
--
--    Connects to the host and port over TLS and runs the application on the
--    given path. It uses 'WebSockets.defaultConnectionOptions' and sends no
--    extra headers; this is exactly 'runSecureClientWith' called with those
--    two values.
--
--    >>> let app _connection = return ()
--    >>> runSecureClient "echo.websocket.org" 443 "/" app
runSecureClient ::
  (MonadIO.MonadIO m) =>
  (Catch.MonadMask m) =>
  -- | Host
  Socket.HostName ->
  -- | Port
  Socket.PortNumber ->
  -- | Path
  String.String ->
  -- | Application
  WebSockets.ClientApp a ->
  m a
runSecureClient host port path app = do
  let options = WebSockets.defaultConnectionOptions
  -- No extra request headers are sent.
  runSecureClientWith host port path options [] app

-- |
--    A secure replacement for 'Network.WebSockets.runClientWith'.
--
--    This is 'runSecureClientWithConfig' called with 'defaultConfig'. Use
--    that function directly to customise how bytes are read from the
--    connection.
--
--    >>> let options = defaultConnectionOptions
--    >>> let headers = []
--    >>> let app _connection = return ()
--    >>> runSecureClientWith "echo.websocket.org" 443 "/" options headers app
--
--    If you want to run a secure client without certificate validation, use
--    'Network.WebSockets.runClientWithStream'. For example:
--
--    > let host = "echo.websocket.org"
--    > let port = 443
--    > let path = "/"
--    > let options = defaultConnectionOptions
--    > let headers = []
--    > let tlsSettings = TLSSettingsSimple
--    >     -- This is the important setting.
--    >     { settingDisableCertificateValidation = True
--    >     , settingDisableSession = False
--    >     , settingUseServerName = False
--    >     }
--    > let connectionParams = ConnectionParams
--    >     { connectionHostname = host
--    >     , connectionPort = port
--    >     , connectionUseSecure = Just tlsSettings
--    >     , connectionUseSocks = Nothing
--    >     }
--    >
--    > context <- initConnectionContext
--    > connection <- connectTo context connectionParams
--    > stream <- makeStream
--    >     (fmap Just (connectionGetChunk connection))
--    >     (maybe (return ()) (connectionPut connection . toStrict))
--    > runClientWithStream stream host path options headers $ \ connection -> do
--    >     -- Do something with the connection.
--    >     return ()
--
--    Unlike this library's own reader, the sketch above does not turn an
--    end-of-file 'IOError' into 'Nothing'. It also does not close the
--    connection when the application finishes.
runSecureClientWith ::
  (MonadIO.MonadIO m) =>
  (Catch.MonadMask m) =>
  -- | Host
  Socket.HostName ->
  -- | Port
  Socket.PortNumber ->
  -- | Path
  String.String ->
  -- | Options
  WebSockets.ConnectionOptions ->
  -- | Headers
  WebSockets.Headers ->
  -- | Application
  WebSockets.ClientApp a ->
  m a
runSecureClientWith host port path options headers app = do
  let config = defaultConfig
  runSecureClientWithConfig host port path config options headers app

-- | Configures a secure WebSocket connection.
newtype Config = Config
  { -- | How to get bytes from the connection. Typically
    -- 'Connection.connectionGetChunk', but could be something else like
    -- 'Connection.connectionGetLine'. If this action throws an end-of-file
    -- 'IO.Error.IOError', the stream is treated as finished. Any other
    -- 'IO.Error.IOError' is rethrown.
    connectionGet :: Connection.Connection -> IO.IO StrictBytes.ByteString
  }

-- | The default 'Config' value used by 'runSecureClientWith'. It reads from
-- the connection with 'Connection.connectionGetChunk'.
defaultConfig :: Config
defaultConfig = do
  Config {connectionGet = Connection.connectionGetChunk}

-- | Runs a secure WebSockets client with the given 'Config'.
--
-- The function first creates a fresh connection context. It then opens a
-- TLS connection to the host and port inside 'Catch.bracket', so
-- 'Connection.connectionClose' runs whether the application returns or
-- throws. The connection is wrapped in a websockets stream. Reads go through
-- the config's 'connectionGet' and writes go through
-- 'Connection.connectionPut'.
runSecureClientWithConfig ::
  (MonadIO.MonadIO m) =>
  (Catch.MonadMask m) =>
  -- | Host
  Socket.HostName ->
  -- | Port
  Socket.PortNumber ->
  -- | Path
  String.String ->
  -- | Config
  Config ->
  -- | Options
  WebSockets.ConnectionOptions ->
  -- | Headers
  WebSockets.Headers ->
  -- | Application
  WebSockets.ClientApp a ->
  m a
runSecureClientWithConfig host port path config options headers app = do
  context <- MonadIO.liftIO Connection.initConnectionContext
  Catch.bracket
    -- Acquire: open the TLS connection.
    (MonadIO.liftIO $ Connection.connectTo context (connectionParams host port))
    -- Release: always close it, even if the application throws.
    (MonadIO.liftIO . Connection.connectionClose)
    ( \connection -> MonadIO.liftIO $ do
        stream <-
          Stream.makeStream
            (reader config connection)
            (writer connection)
        WebSockets.runClientWithStream stream host path options headers app
    )

-- | Parameters for a TLS connection to the given host and port. The
-- connection uses 'tlsSettings' and does not go through a SOCKS proxy.
connectionParams ::
  Socket.HostName -> Socket.PortNumber -> Connection.ConnectionParams
connectionParams host port = do
  Connection.ConnectionParams
    { Connection.connectionHostname = host,
      Connection.connectionPort = port,
      Connection.connectionUseSecure = Maybe.Just tlsSettings,
      Connection.connectionUseSocks = Maybe.Nothing
    }

-- | TLS settings used for every connection this module opens. Certificate
-- validation is left enabled: 'Connection.settingDisableCertificateValidation'
-- is 'Bool.False'. The documentation of 'runSecureClientWith' shows how to
-- connect without validation.
tlsSettings :: Connection.TLSSettings
tlsSettings = do
  Connection.TLSSettingsSimple
    { Connection.settingDisableCertificateValidation = Bool.False,
      Connection.settingDisableSession = Bool.False,
      Connection.settingUseServerName = Bool.False
    }

-- | Adapts the config's 'connectionGet' to the reading action expected by
-- 'Stream.makeStream'.
--
-- * A successful read is wrapped in 'Maybe.Just'.
-- * An end-of-file 'IO.Error.IOError' becomes 'Maybe.Nothing', which marks
--   the end of the stream.
-- * Any other 'IO.Error.IOError' is rethrown unchanged.
reader ::
  Config ->
  Connection.Connection ->
  IO.IO (Maybe.Maybe StrictBytes.ByteString)
reader config connection =
  IO.Error.catchIOError
    ( do
        chunk <- connectionGet config connection
        Applicative.pure (Maybe.Just chunk)
    )
    ( \e ->
        if IO.Error.isEOFError e
          then Applicative.pure Maybe.Nothing
          else Exception.throwIO e
    )

-- | Adapts 'Connection.connectionPut' to the writing action expected by
-- 'Stream.makeStream'.
--
-- * 'Maybe.Nothing' is a no-op. The connection is not closed here;
--   'runSecureClientWithConfig' closes it through 'Catch.bracket'.
-- * 'Maybe.Just' bytes are converted from lazy to strict and then sent.
writer ::
  Connection.Connection -> Maybe.Maybe LazyBytes.ByteString -> IO.IO ()
writer connection maybeBytes = do
  case maybeBytes of
    Maybe.Nothing -> do
      Applicative.pure ()
    Maybe.Just bytes -> do
      Connection.connectionPut connection (LazyBytes.toStrict bytes)